diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java b/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java index 3ee4b761e1b304..375c95da25afa0 100644 --- a/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java +++ b/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java @@ -35,7 +35,7 @@ public final class OffsetAwareOutputStream implements Closeable { private long position; - OffsetAwareOutputStream(OutputStream currentOut, long position) { + public OffsetAwareOutputStream(OutputStream currentOut, long position) { this.currentOut = checkNotNull(currentOut); this.position = position; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java index 4173bb7140e787..186c8146c3c703 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java @@ -19,6 +19,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -41,6 +42,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION; @@ -161,6 +163,48 @@ void writeInput( } } + void writeInputFromSpill( + JobVertexID jobVertexID, int subtaskIndex, 
FetchedChannelStateReader reader) { + if (isDone()) { + try { + reader.close(); + } catch (Exception ignored) { + } + return; + } + ChannelStatePendingResult pendingResult = + getChannelStatePendingResult(jobVertexID, subtaskIndex); + runWithChecks( + () -> { + checkState(!pendingResult.isAllInputsReceived()); + try { + String action = "ChannelStateCheckpointWriter#writeInputFromSpill"; + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + long offset = checkpointStream.getPos(); + try (AutoCloseable ignored = + NetworkActionsLogger.measureIO(action, seg.channelInfo())) { + serializer.writeData(dataStream, seg.bodyStream(), seg.length()); + } + long size = checkpointStream.getPos() - offset; + pendingResult + .getInputChannelOffsets() + .computeIfAbsent( + seg.channelInfo(), unused -> new StateContentMetaInfo()) + .withDataAdded(offset, size); + NetworkActionsLogger.tracePersist( + action, + seg.length() + " bytes", + seg.channelInfo(), + checkpointId); + } + } finally { + reader.close(); + } + }); + } + void writeOutput( JobVertexID jobVertexID, int subtaskIndex, ResultSubpartitionInfo info, Buffer buffer) { try { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java index b257c3b40544e0..0b6976068d7abd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -101,24 +101,18 @@ public static ChannelStateFilteringHandler createFromContext( } /** - * Filters a recovered buffer from the specified virtual channel, returning new buffers - * containing only the records that belong to the current subtask. - * - *

One source buffer may produce 0 to N result buffers: 0 if all records are filtered out, - * and potentially more than 1 when a spanning record completes in this buffer. The deserializer - * caches partial record data from previous buffers, so the output may contain data that was not - * in the current source buffer, causing the total output size to exceed one buffer capacity. - * This can happen with any spanning record regardless of its size. - * - * @return filtered buffers, possibly empty if all records were filtered out. + * Filters {@code sourceBuffer} through the virtual channel identified by {@code gateIndex} / + * {@code oldChannelIndex}, appending each surviving record (length-prefixed) into {@code + * outputSerializer}. One call may emit 0..N records depending on the filter result and whether + * records spanning previous buffers complete here. The caller owns the segment boundary. */ - public List filterAndRewrite( + public void filterAndRewrite( int gateIndex, int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { + DataOutputSerializer outputSerializer) + throws IOException { if (gateIndex < 0 || gateIndex >= gateHandlers.length) { throw new IllegalStateException( @@ -135,8 +129,8 @@ public List filterAndRewrite( + gateIndex + ". This gate is not a network input and should not have recovered buffers."); } - return gateHandler.filterAndRewrite( - oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier); + gateHandler.filterAndRewrite( + oldSubtaskIndex, oldChannelIndex, sourceBuffer, outputSerializer); } /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. 
*/ @@ -215,7 +209,8 @@ private static GateFilterHandler createGateHandler( : VirtualChannelRecordFilterFactory.createPassThroughFilter(); RecordDeserializer> deserializer = - createDeserializer(filterContext.getTmpDirectories()); + new SpillingAdaptiveSpanningRecordDeserializer<>( + filterContext.getTmpDirectories()); VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); gateVirtualChannels.put(key, vc); @@ -246,26 +241,10 @@ private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int nu return oldIndexes.stream().mapToInt(Integer::intValue).toArray(); } - private static RecordDeserializer> createDeserializer( - String[] tmpDirectories) { - if (tmpDirectories != null && tmpDirectories.length > 0) { - return new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories); - } else { - String[] defaultDirs = new String[] {System.getProperty("java.io.tmpdir")}; - return new SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs); - } - } - // ------------------------------------------------------------------------------------------- // Inner classes // ------------------------------------------------------------------------------------------- - /** Provides buffers for re-serializing filtered records. Implementations may block. */ - @FunctionalInterface - public interface BufferSupplier { - Buffer requestBufferBlocking() throws IOException, InterruptedException; - } - /** * Handles record filtering for a single input gate. Each gate has its own serializer and set of * virtual channels, allowing different gates to handle different record types independently. 
@@ -275,8 +254,6 @@ static class GateFilterHandler { private final Map> virtualChannels; private final StreamElementSerializer serializer; private final DeserializationDelegate deserializationDelegate; - private final DataOutputSerializer outputSerializer; - private final byte[] lengthBuffer = new byte[4]; GateFilterHandler( Map> virtualChannels, @@ -284,23 +261,21 @@ static class GateFilterHandler { this.virtualChannels = checkNotNull(virtualChannels); this.serializer = checkNotNull(serializer); this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer); - this.outputSerializer = new DataOutputSerializer(128); } /** * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record - * filter, and immediately re-serializes each surviving record into output buffers. + * filter, and re-serializes each surviving record into {@code outputSerializer}. No + * intermediate network buffer is used; the caller owns the segment boundary. */ - List filterAndRewrite( + void filterAndRewrite( int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { + DataOutputSerializer outputSerializer) + throws IOException { boolean sourceBufferOwnershipTransferred = false; - List resultBuffers = new ArrayList<>(); - Buffer currentBuffer = null; try { SubtaskConnectionDescriptor key = new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); @@ -319,132 +294,33 @@ List filterAndRewrite( while (true) { DeserializationResult result = vc.getNextRecord(deserializationDelegate); if (result.isFullRecord()) { - if (currentBuffer == null) { - currentBuffer = bufferSupplier.requestBufferBlocking(); - } - currentBuffer = - serializeElement( - deserializationDelegate.getInstance(), - currentBuffer, - resultBuffers, - bufferSupplier); + serializeElement(deserializationDelegate.getInstance(), outputSerializer); } if (result.isBufferConsumed()) { break; } } - - if 
(currentBuffer != null) { - if (currentBuffer.readableBytes() > 0) { - resultBuffers.add(currentBuffer); - } else { - currentBuffer.recycleBuffer(); - } - currentBuffer = null; - } - - return resultBuffers; } catch (Throwable t) { if (!sourceBufferOwnershipTransferred) { sourceBuffer.recycleBuffer(); } - // Avoid double-recycle: currentBuffer may already be the last element in - // resultBuffers if serializeElement added it before the exception. - if (currentBuffer != null - && (resultBuffers.isEmpty() - || resultBuffers.get(resultBuffers.size() - 1) != currentBuffer)) { - currentBuffer.recycleBuffer(); - } - for (Buffer buf : resultBuffers) { - buf.recycleBuffer(); - } - resultBuffers.clear(); throw t; } } /** - * Serializes a single stream element into the current buffer using the length-prefixed - * format (4-byte big-endian length + record bytes) expected by Flink's record - * deserializers. Spills into new buffers from {@code bufferSupplier} when needed. - * - * @return the buffer to continue writing into (may differ from the input buffer). + * Appends one stream element as a length-prefixed record. Reserves the 4B prefix, + * serializes the element, then backfills the length, because {@code outputSerializer} + * already holds the segment header and earlier records, so the prefix cannot be written + * from a fixed offset. 
*/ - private Buffer serializeElement( - StreamElement element, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { - outputSerializer.clear(); + private void serializeElement(StreamElement element, DataOutputSerializer outputSerializer) + throws IOException { + int startPos = outputSerializer.length(); + outputSerializer.writeInt(0); // length placeholder serializer.serialize(element, outputSerializer); - int recordLength = outputSerializer.length(); - - writeLengthToBuffer(recordLength); - currentBuffer = - writeDataToBuffer( - lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier); - - byte[] serializedData = outputSerializer.getSharedBuffer(); - currentBuffer = - writeDataToBuffer( - serializedData, - 0, - recordLength, - currentBuffer, - resultBuffers, - bufferSupplier); - return currentBuffer; - } - - private void writeLengthToBuffer(int length) { - lengthBuffer[0] = (byte) (length >> 24); - lengthBuffer[1] = (byte) (length >> 16); - lengthBuffer[2] = (byte) (length >> 8); - lengthBuffer[3] = (byte) length; - } - - /** - * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier} - * when the current one is full. - * - * @return the buffer to continue writing into (may differ from the input buffer). 
- */ - private Buffer writeDataToBuffer( - byte[] data, - int dataOffset, - int dataLength, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { - int offset = dataOffset; - int remaining = dataLength; - - while (remaining > 0) { - int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize(); - - if (writableBytes == 0) { - // Buffer is full, transfer ownership to resultBuffers - resultBuffers.add(currentBuffer); - currentBuffer = bufferSupplier.requestBufferBlocking(); - writableBytes = currentBuffer.getMaxCapacity(); - } - - int bytesToWrite = Math.min(remaining, writableBytes); - currentBuffer - .getMemorySegment() - .put( - currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(), - data, - offset, - bytesToWrite); - currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite); - - offset += bytesToWrite; - remaining -= bytesToWrite; - } - return currentBuffer; + int recordLength = outputSerializer.length() - startPos - Integer.BYTES; + outputSerializer.writeIntUnsafe(recordLength, startPos); } boolean hasPartialData() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java index 252d25c2e29fd8..ec858460dd89b0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java @@ -43,6 +43,8 @@ interface ChannelStateSerializer { void writeData(DataOutputStream stream, Buffer... 
flinkBuffers) throws IOException; + void writeData(DataOutputStream stream, InputStream input, int length) throws IOException; + void readHeader(InputStream stream) throws IOException; int readLength(InputStream stream) throws IOException; @@ -165,6 +167,18 @@ public void writeData(DataOutputStream stream, Buffer... flinkBuffers) throws IO } } + @Override + public void writeData(DataOutputStream stream, InputStream input, int length) + throws IOException { + Preconditions.checkArgument(length >= 0, "negative state size"); + stream.writeInt(length); + long copied = input.transferTo(stream); + if (copied != length) { + throw new java.io.EOFException( + "Unexpected EOF: expected " + length + " bytes of segment body, got " + copied); + } + } + private int getSize(Buffer[] buffers) { int len = 0; for (Buffer buffer : buffers) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java index abef241c325b80..d1913df0416c11 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java @@ -258,6 +258,20 @@ static ChannelStateWriteRequest abort( return new CheckpointAbortRequest(jobVertexID, subtaskIndex, checkpointId, cause); } + static ChannelStateWriteRequest replayInputDataFromSpill( + JobVertexID jobVertexID, + int subtaskIndex, + long checkpointId, + FetchedChannelStateReader reader) { + return new CheckpointInProgressRequest( + "writeInputFromSpill", + jobVertexID, + subtaskIndex, + checkpointId, + writer -> writer.writeInputFromSpill(jobVertexID, subtaskIndex, reader), + throwable -> reader.close()); + } + static ChannelStateWriteRequest registerSubtask(JobVertexID jobVertexID, int subtaskIndex) { return new SubtaskRegisterRequest(jobVertexID, 
subtaskIndex); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java index 6fee1402036d69..b603a53a09ebb4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java @@ -190,6 +190,9 @@ void addOutputDataFuture( ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) throws IllegalArgumentException; + /** Records input-channel state from a spill file and takes ownership of {@code reader}. */ + void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader); + ChannelStateWriter NO_OP = new NoOpChannelStateWriter(); /** No-op implementation of {@link ChannelStateWriter}. */ @@ -231,6 +234,14 @@ public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) { CompletableFuture.completedFuture(Collections.emptyList())); } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + try { + reader.close(); + } catch (Exception ignored) { + } + } + @Override public void close() {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java index 40d7ddffd1e18d..21db97355db7c1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java @@ -42,6 +42,7 @@ import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput; +import static 
org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.replayInputDataFromSpill; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write; /** @@ -235,6 +236,12 @@ public void finishOutput(long checkpointId) { enqueue(completeOutput(jobVertexID, subtaskIndex, checkpointId), false); } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + LOG.debug("{} replaying input data from spill, checkpoint {}", taskName, checkpointId); + enqueue(replayInputDataFromSpill(jobVertexID, subtaskIndex, checkpointId, reader), false); + } + @Override public void abort(long checkpointId, Throwable cause, boolean cleanup) { LOG.debug("{} aborting, checkpoint {}", taskName, checkpointId); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java new file mode 100644 index 00000000000000..5e42b8217b9dc0 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Sealed container for recovered channel-state data written to spill files. + * + *

Holds a list of file paths (in write order). Segment boundaries are self-described in disk + * segment headers ([4B gateIdx][4B channelIdx][4B bufferLength]), so no in-memory segment locator + * table is maintained. The reader scans files sequentially, reading each 12-byte header to obtain + * the channel info and body length. + * + *

The file list is built by the spill writer, which rotates to a new file at each 64 MB soft
+ * limit; it is sealed before it reaches this class and never grows afterwards.
+ *
+ *

File lifecycle is managed by {@link #acquire()} / {@link #release()} reference counting. Files + * are deleted only when the last lifecycle grant is released (i.e. when both the drain reader and + * all snapshot readers have finished). + * + *

Mutations (file list appends) are single-writer and intentionally unsynchronized; callers must + * serialize them via the channel IO executor. + */ +@Internal +public final class FetchedChannelState implements Closeable { + + /** Ordered list of spill file paths, one entry per physical file. Sealed at construction. */ + private final List files; + + // close() and release() may be called from different threads; volatile ensures visibility. + private volatile boolean closed = false; + + private final AtomicInteger refCount = new AtomicInteger(0); + + private final AtomicBoolean cleanedUp = new AtomicBoolean(false); + + /** + * Wraps an already-written, ordered list of spill files. The list is sealed; it never grows. + */ + FetchedChannelState(List files) { + this.files = new ArrayList<>(checkNotNull(files)); + } + + // ------------------------------------------------------------------------------------------- + // Read-phase API (called by the reader after the writer is sealed) + // ------------------------------------------------------------------------------------------- + + /** + * Opens a root reader covering all segments from the beginning. The returned reader holds one + * lifecycle grant and must be closed when done. + */ + public FetchedChannelStateReader reader() { + return new FetchedChannelStateSnapshot( + this, FetchedChannelStateReaderImpl.Position.atStart()) + .reader(); + } + + /** Returns the ordered list of spill file paths. Read-only view. */ + public List files() { + return Collections.unmodifiableList(files); + } + + // ------------------------------------------------------------------------------------------- + // Lifecycle + // ------------------------------------------------------------------------------------------- + + /** Acquires a lifecycle grant for a reader or handoff owner. */ + public void acquire() { + refCount.incrementAndGet(); + } + + /** + * Releases a lifecycle grant. 
When the last grant is released (refCount reaches zero), all + * spill files are deleted. This preserves the invariant that files exist for the lifetime of + * all readers (drain + snapshot) and are cleaned up exactly once when the last reader finishes. + */ + public void release() throws IOException { + if (refCount.decrementAndGet() == 0) { + if (cleanedUp.compareAndSet(false, true)) { + closed = true; + deleteAllFiles(); // review todo: threads / call-sites + } + } + } + + /** Forces cleanup even when lifecycle grants are still outstanding. */ + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + if (cleanedUp.compareAndSet(false, true)) { + deleteAllFiles(); + } + } + + private void deleteAllFiles() throws IOException { + IOException firstError = null; + for (Path file : files) { + try { + Files.deleteIfExists(file); + } catch (IOException e) { + if (firstError == null) { + firstError = e; + } else { + firstError.addSuppressed(e); + } + } + } + if (firstError != null) { + throw firstError; + } + } + + @VisibleForTesting + public boolean isClosed() { + return closed; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java new file mode 100644 index 00000000000000..58ebb1d733be13 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Drains a {@link FetchedChannelState} into recovered-buffer queues and snapshots remaining + * segments when a checkpoint fires during recovery. + * + *

The drainer lock pairs channel delivery with reader-cursor advancement and also protects + * snapshot creation plus barrier insertion. Disk reads and buffer allocation stay outside that + * lock. + */ +@Internal +public final class FetchedChannelStateDrainer implements RecoveryCheckpointTrigger, Closeable { + + private final FetchedChannelStateReader rootReader; + + private final ResolvedChannels channels; + + private final Object lock = new Object(); + private final FetchedChannelState channelState; + + /** + * Set under {@link #lock} once {@link #drain()} has consumed every segment. After that the + * {@link #rootReader} is closed by {@link #close()}, so a later {@link + * #snapshotAndInsertBarriers} must not derive from it; it returns an empty reader instead. + * Guarded by the lock so the check is atomic with barrier insertion. + */ + private boolean drainFinished; + + public FetchedChannelStateDrainer( + FetchedChannelState channelState, List channels) { + this.channelState = channelState; + this.rootReader = checkNotNull(channelState).reader(); + this.channels = new ResolvedChannels(channels); + } + + private static final class ResolvedChannels { + final List allChannels; + final Map channelByInfo; + + ResolvedChannels(List all) { + this.allChannels = all; + Map byInfo = new HashMap<>(); + for (RecoverableInputChannel ch : all) { + byInfo.put(ch.getChannelInfo(), ch); + } + this.channelByInfo = byInfo; + } + } + + /** + * Drains all segments from the spill file into the corresponding recovery buffer queues. Each + * segment is split into chunks of at most {@code memorySegmentSize} bytes; a full chunk is + * delivered under the drainer lock paired with a segment commit. After all segments are + * drained, every channel's {@link RecoverableInputChannel#finishRecoveredBufferDelivery()} is + * called. + * + *

Disk reads and buffer allocations happen outside the lock; only the "deliver + commit"
+     * pair is locked to guarantee atomicity with snapshot.
+     */
+    public void drain() throws IOException, InterruptedException { // review: enforce called once
+        // No release() here: dropping the last lifecycle grant deletes the spill files.
+        Optional<SpillSegment> next;
+        while ((next = rootReader.nextSegment()).isPresent()) {
+            SpillSegment seg = next.get();
+            RecoverableInputChannel ch = channels.channelByInfo.get(seg.channelInfo());
+            if (ch == null) {
+                throw new IllegalStateException(
+                        "Drain: no physical channel found for " + seg.channelInfo());
+            }
+            drainSegment(seg, ch);
+        }
+
+        // Mark drain done before rootReader is closed, so a concurrent snapshot returns empty
+        // rather than deriving from the soon-to-be-closed rootReader. Under the lock to stay atomic
+        // with snapshotAndInsertBarriers' check.
+        synchronized (lock) {
+            drainFinished = true;
+        }
+        for (RecoverableInputChannel ch : channels.allChannels) {
+            ch.finishRecoveredBufferDelivery();
+        }
+    }
+
+    /**
+     * Drains one segment into the given channel. Fills buffers from the segment's opaque byte
+     * stream in chunks of at most {@code memorySegmentSize} bytes. A full buffer is delivered under
+     * the lock and a fresh one is requested; a partial tail buffer (if non-empty) is also
+     * delivered.
+     */
+    private void drainSegment(SpillSegment seg, RecoverableInputChannel ch)
+            throws IOException, InterruptedException {
+        InputStream in = seg.bodyStream();
+        Buffer buf = ch.requestRecoveryBufferBlocking();
+        int cap = buf.getMaxCapacity();
+
+        while (fill(buf, in, cap - buf.getSize()) > 0) {
+            if (buf.getSize() == cap) {
+                // Buffer is full: deliver under lock and request a fresh one.
+                synchronized (lock) {
+                    ch.onRecoveredStateBuffer(buf);
+                    seg.commit();
+                }
+                buf = ch.requestRecoveryBufferBlocking();
+                cap = buf.getMaxCapacity();
+            }
+            // If buf is not full yet, the fill returned > 0 bytes but segment is not exhausted;
+            // loop and keep filling the same buffer.
+ } + + if (buf.getSize() > 0) { + // Deliver the partial tail buffer. + synchronized (lock) { + ch.onRecoveredStateBuffer(buf); + seg.commit(); + } + } else { + buf.recycleBuffer(); + } + } + + /** + * Fills up to {@code remaining} bytes from {@code in} into {@code buf}. Returns the number of + * bytes actually written; returns 0 if the stream is at EOF. Does not close or recycle {@code + * buf}; ownership stays with the caller. + */ + private static int fill(Buffer buf, InputStream in, int remaining) throws IOException { + if (remaining == 0) { + return 0; + } + // Do not use try-with-resources: ChannelStateByteBuffer.close() recycles the buffer, + // but the buffer is still owned by the caller here. + ChannelStateByteBuffer view = ChannelStateByteBuffer.wrap(buf); // review: try-without ? + return view.writeBytes(in, remaining); + } + + /** + * Atomically snapshots the undrained portion of the spill and inserts {@link + * RecoveryCheckpointBarrier}s into all in-recovery channels. Returns an independent reader over + * the remaining segments for replay into the checkpoint stream; the caller owns and must close + * it. + * + *

If the drain has already finished, the root reader is closed and there is nothing left to + * snapshot; an empty reader is returned so the caller's normal flow handles it uniformly. + */ + @Override + public FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) + throws IOException { + + // Barrier insertion and snapshot must occur within the same critical section so that the + // snapshot's committed position reflects exactly the drain position at the moment barriers + // were inserted, with no window for the drain thread to advance between. + synchronized (lock) { + for (RecoverableInputChannel ch : channels.allChannels) { + ch.insertRecoveryCheckpointBarrierIfInRecovery(checkpointId); + } + if (drainFinished) { + // Drain consumed everything and rootReader is (being) closed; nothing left to + // snapshot. Return an empty reader so the caller's normal flow handles it. + return FetchedChannelStateReader.emptyReader(); + } + return rootReader.snapshot().reader(); + } + } + + @Override + public void close() throws IOException { + rootReader.close(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java new file mode 100644 index 00000000000000..5e65f7f6b79c2c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +import java.io.Closeable; +import java.io.InputStream; +import java.util.Collections; +import java.util.Optional; + +/** + * Forward reader over a {@link FetchedChannelState}'s spill files. This is our own segment reader, + * on purpose not a Java {@link java.util.Iterator}: our access pattern ("a body must be + * fully read before the next segment", "body ownership is handed to the consumer", "consume and + * commit are separate steps") does not fit the {@code hasNext/next} contract. + * + *

This interface is the contract callers depend on; {@link FetchedChannelStateReaderImpl} holds + * the implementation (the live file stream, the two progress positions, the bounded body view). + * + *

Reading is strictly sequential: a reader is positioned once (offset 0 for the root reader, or + * the committed position for a {@link #snapshot()}), then consumes forward only via {@link + * #nextSegment()}. It never seeks backward and never re-positions mid-iteration. + * + *

The drain thread reads the root reader front to back and commits via {@link + * SpillSegment#commit()}; each checkpoint derives a fresh {@link #snapshot()} that resumes from the + * committed position. {@link #snapshot()} and {@link SpillSegment#commit()} must be called under + * the drainer lock; disk reads happen outside it. + */ +@Internal +public interface FetchedChannelStateReader extends Closeable { + + /** + * Advances to the next segment and returns it, or {@link Optional#empty()} when no segment + * remains. Advancing and probing are one step; there is no separate {@code hasNext}. + * + *

Entry rule (the first call is exempt): the previous segment's body must be fully read, + * otherwise this is a contract violation and fails loud (no skip-ahead). + */ + Optional nextSegment(); + + /** + * Derives an independent resume point starting from the committed position. The snapshot holds + * its own {@link FetchedChannelState} lifecycle grant; the caller must open a reader from it + * via {@link FetchedChannelStateSnapshot#reader()} and close that reader when done. + * + *

Must be called under the drainer lock so that the copied position reflects the latest + * committed state. + * + * @return a snapshot capturing the current committed position; caller must open and close a + * reader from it + */ + FetchedChannelStateSnapshot snapshot(); + + /** + * Returns a reader with no segments — its first {@link #nextSegment()} is empty. Each call + * hands out a fresh instance: readers have independent lifecycle and {@link #close()} is + * single-use, so a shared instance would let one consumer's close break later consumers. Used + * wherever there is nothing to snapshot (e.g. after drain finished, or the no-op recovery + * trigger). + */ + static FetchedChannelStateReader emptyReader() { + return new FetchedChannelState(Collections.emptyList()).reader(); + } + + /** + * One per-channel segment produced by {@link #nextSegment()}. + * + *

The segment body bytes are opaque to the reader; record framing is handled by the + * consumer's deserializer. A consumer reads {@link #bodyStream()} to EOF (after {@link + * #length()} bytes), and the drain consumer additionally calls {@link #commit()} under the + * drainer lock after each delivery so that a later {@link FetchedChannelStateReader#snapshot()} + * resumes from the delivered boundary. + * + *

Ownership of {@link #bodyStream()} passes to the consumer: the reader no longer tracks how + * far it has been read. The "previous body must be fully read" rule (no skip-ahead) is enforced + * at the next {@link FetchedChannelStateReader#nextSegment()} call, not here. + * + *

A segment is valid only until the next {@code nextSegment()} call on the parent reader. + */ + interface SpillSegment { + + /** The channel whose data this segment contains. */ + InputChannelInfo channelInfo(); + + /** + * Returns an {@link InputStream} bounded to this segment's body. Reading returns {@code -1} + * (EOF) after {@link #length()} bytes; it never reads into the next segment or the next + * file. + * + *

The stream is single-use, not thread-safe, and must be fully consumed before the next + * {@link FetchedChannelStateReader#nextSegment()}. + */ + InputStream bodyStream(); + + /** + * Number of body bytes this segment hands out before EOF. For the snapshot path this is the + * not-yet-delivered remainder used as the length prefix when writing to the checkpoint + * stream. Bounded by the spill file size limit, so it always fits in an {@code int}. + */ + int length(); + + /** + * Advances the reader's committed position to match how many body bytes have been read from + * {@link #bodyStream()} so far. Must be called under the drainer lock after each buffer + * delivery so that a subsequent {@link FetchedChannelStateReader#snapshot()} sees the + * correct delivered boundary. Only the drain (root) reader commits; the snapshot reader + * never does. + */ + void commit(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java new file mode 100644 index 00000000000000..cb8b6dca4c2469 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; + +import javax.annotation.Nullable; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.runtime.checkpoint.channel.AbstractSpillingHandler.SEGMENT_HEADER_BYTES; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * The single {@link FetchedChannelStateReader} implementation over a {@link FetchedChannelState}'s + * spill files. + * + *

Reading is strictly sequential and never seeks mid-iteration. There is exactly one place that + * skips bytes: the very first {@link #nextSegment()} call, where a snapshot reader started mid-body + * discards the already-delivered prefix to land on the not-yet-delivered remainder. Every later + * call does no skipping at all — the previous body was read to its end, so the stream already sits + * on the next segment's header. This "skip only on first positioning" rule is what keeps the + * steady-state path free of any seek/skip. + * + *

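The reader holds two {@link Position}s and nothing else duplicates them:
+ *
+ * <ul>
+ *   <li>{@code current}: the live read position; its {@code readOffset} is where the open
+ *       stream physically sits.
+ *   <li>{@code committed}: the delivered boundary; {@link SpillSegment#commit()} advances it
+ *       from {@code current}.
+ * </ul>
+ *
+ *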

The "previous body fully read before advancing" rule is checked at the {@link #nextSegment()} + * entry (the first call is exempt — there is no previous segment). Body ownership is handed to the + * consumer, so the reader does not track body progress except through {@code current}. + */ +@Internal +final class FetchedChannelStateReaderImpl implements FetchedChannelStateReader { + + private final FetchedChannelStateSnapshot snapshot; + private final FetchedChannelState channelState; + private final List files; + + /** Live read position; {@code readOffset} is where the open stream physically sits. */ + private final Position current; + + /** Delivered boundary; {@link SpillSegment#commit()} advances it from {@link #current}. */ + private final Position committed; + + /** Open stream over {@code current.fileIndex}, or {@code null} before the first read. */ + @Nullable private InputStream fileStream; + + /** Size of the file currently open. */ + private long currentFileSize; + + /** Body view of the segment returned by the last {@link #nextSegment()}, or {@code null}. */ + @Nullable private BoundedSegmentStream currentBody; + + private boolean positioned; + private boolean closed; + + FetchedChannelStateReaderImpl(FetchedChannelStateSnapshot snapshot) { + this.snapshot = snapshot; + this.channelState = snapshot.channelState(); + this.files = channelState.files(); + // Must copy the position so that this reader's commits do not mutate the snapshot's state. + this.committed = snapshot.position().copy(); + this.current = committed.copy(); + } + + @Override + public Optional nextSegment() { + checkState(!closed, "FetchedChannelStateReader is closed"); + checkState( + currentBody == null || currentBody.remaining() == 0, + "Previous segment body not fully consumed before advancing: %s bytes left", + currentBody == null ? 
0 : currentBody.remaining()); + try { + if (!positioned) { + positioned = true; + return firstSegment(); + } + return followingSegment(); + } catch (IOException e) { + throw new RuntimeException("Failed to read segment", e); + } + } + + /** + * First positioning — the only path that may skip bytes. Opens the file at the committed header + * offset and reads the header. A snapshot may resume in the middle of a segment: the committed + * {@code readOffset} says how many body bytes were already delivered, and that prefix is + * skipped so the returned body starts at the not-yet-delivered remainder. If the segment was + * already fully delivered (prefix == whole body), it is exhausted here and we move on to the + * next one. + */ + private Optional firstSegment() throws IOException { + // The committed read offset may sit mid-body (after a partial commit), but the header lives + // at segmentStartOffset. Capture how much was already delivered, then rewind the live read + // offset to the header so we open the file there and read the header, not mid-body. + int deliveredPrefix = (int) current.deliveredBodyBytes(); + current.rewindToSegmentStart(); + + if (!openCurrentFile()) { + return Optional.empty(); + } + SegmentHeader header = readHeaderAtCurrent(); + checkState( + deliveredPrefix <= header.bufferLength, + "Delivered offset %s exceeds segment length %s", + deliveredPrefix, + header.bufferLength); + + if (deliveredPrefix == header.bufferLength) { + // This segment was already fully delivered before the snapshot; nothing remains in it. + // Skip its whole body to reach the next segment's header, then take the steady path. + skipBody(header.bufferLength); + return followingSegment(); + } + + // Discard the already-delivered prefix (the one and only skip in this class), then hand out + // the remainder. alreadyDelivered is carried so commit() records the boundary from the + // head. 
+ skipBody(deliveredPrefix); + currentBody = + new BoundedSegmentStream(header.bufferLength - deliveredPrefix, deliveredPrefix); + return Optional.of(new Segment(header.channelInfo, currentBody)); + } + + /** + * Steady-state path — no skipping. The previous body was read to its end, so the stream sits + * exactly on this segment's header (or at the current file's end, in which case we roll to the + * next file). Reads the header and returns the whole-body view. + */ + private Optional followingSegment() throws IOException { + if (!openCurrentFile()) { + return Optional.empty(); + } + SegmentHeader header = readHeaderAtCurrent(); + currentBody = new BoundedSegmentStream(header.bufferLength); + return Optional.of(new Segment(header.channelInfo, currentBody)); + } + + @Override + public FetchedChannelStateSnapshot snapshot() { + checkState(!closed, "FetchedChannelStateReader is closed"); + return new FetchedChannelStateSnapshot(channelState, committed.copy()); + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + try { + closeFileStream(); + } finally { + snapshot.release(); + } + } + + // ------------------------------------------------------------------------------------------- + // Sequential IO over the spill files; all of it advances current.readOffset / current.fileIndex + // ------------------------------------------------------------------------------------------- + + /** + * Ensures a file is open with the stream positioned at {@code current}'s read offset, ready to + * read this segment's header. Rolls to the next file when the current one is exhausted. Returns + * false when no segment remains. + * + *

{@code current.segmentStartOffset} is set to where the header begins, so a later {@link
+     * SpillSegment#commit()} records the right segment for the snapshot to resume from.
+     *
+     * <p>Exhausted or empty files are skipped one after another, so a zero-length file in the
+     * middle of the list cannot hide the segments stored in the files behind it.
+     *
+     * @return {@code true} if the stream now sits on a segment header, {@code false} if every
+     *     file has been read to its end
+     * @throws IOException if a file cannot be opened, positioned, or closed
+     */
+    private boolean openCurrentFile() throws IOException {
+        while (current.fileIndex < files.size()) {
+            openFileAndSeek();
+            if (current.readOffset < currentFileSize) {
+                current.startSegmentHere();
+                return true;
+            }
+            // Current file fully read (or empty): release it and move on to the next file's
+            // first segment. Rolling resets the read offset, so the next open starts at 0.
+            closeFileStream();
+            current.rollToNextFile();
+        }
+        return false;
+    }
+
+    /**
+     * Reads the {@code SEGMENT_HEADER_BYTES}-long header at the current read offset and advances
+     * past it. The header is three ints in {@link DataInputStream} order: gate index, channel
+     * index, body length.
+     *
+     * @throws IOException if the header is truncated or the file cannot be read
+     * @throws IllegalStateException if the stored body length is negative
+     */
+    private SegmentHeader readHeaderAtCurrent() throws IOException {
+        byte[] headerBytes = new byte[SEGMENT_HEADER_BYTES];
+        readFully(headerBytes);
+        DataInputStream h = new DataInputStream(new ByteArrayInputStream(headerBytes));
+        int gateIdx = h.readInt();
+        int channelIdx = h.readInt();
+        int bufferLength = h.readInt();
+        checkState(bufferLength >= 0, "negative segment length: %s", bufferLength);
+        return new SegmentHeader(new InputChannelInfo(gateIdx, channelIdx), bufferLength);
+    }
+
+    /**
+     * Ensures the file at {@code current.fileIndex} is open with the stream positioned at {@code
+     * current.readOffset}. If a stream is already open it is left as-is: sequential reading
+     * guarantees it is already there.
+     *
+     * @throws IOException if the file cannot be opened or is shorter than the read offset
+     */
+    private void openFileAndSeek() throws IOException {
+        if (fileStream != null) {
+            return;
+        }
+        Path path = files.get(current.fileIndex);
+        currentFileSize = Files.size(path);
+        InputStream in = Files.newInputStream(path);
+        try {
+            skipOnStream(in, current.readOffset, path);
+        } catch (IOException e) {
+            // Keep the positioning failure as the primary error; a failing close must not
+            // replace it, so it is attached as a suppressed exception instead.
+            try {
+                in.close();
+            } catch (IOException closeFailure) {
+                e.addSuppressed(closeFailure);
+            }
+            throw e;
+        }
+        // Publish the stream only once it is positioned, so a failed open leaves no stale stream.
+        fileStream = in;
+    }
+
+    /** Skips {@code count} body bytes on the open stream, advancing the read offset.
*/
+    private void skipBody(long count) throws IOException {
+        if (count > 0) {
+            skipOnStream(fileStream, count, files.get(current.fileIndex));
+            current.advanceReadOffset(count);
+        }
+    }
+
+    /**
+     * Skips exactly {@code count} bytes on {@code in}, failing loud if the file ends early.
+     *
+     * @param in stream to advance
+     * @param count number of bytes to skip, relative to the stream's current position
+     * @param path file backing {@code in}; used only in the error message
+     * @throws EOFException if the stream ends before {@code count} bytes were skipped
+     */
+    private void skipOnStream(InputStream in, long count, Path path) throws IOException {
+        long skipped = 0;
+        while (skipped < count) {
+            long s = in.skip(count - skipped);
+            if (s <= 0) {
+                // skip can return 0 near EOF; read-and-discard as a fallback.
+                if (in.read() < 0) {
+                    // Report progress against the requested count: the count is relative to
+                    // the stream position, so it is not necessarily a file offset.
+                    throw new EOFException(
+                            "Unexpected end of spill file "
+                                    + path
+                                    + ": skipped only "
+                                    + skipped
+                                    + " of "
+                                    + count
+                                    + " bytes");
+                }
+                skipped++;
+            } else {
+                skipped += s;
+            }
+        }
+    }
+
+    /**
+     * Fills {@code buf} completely from the open file stream, advancing the read offset by every
+     * byte read.
+     *
+     * @throws EOFException if the file ends before {@code buf} is full
+     */
+    private void readFully(byte[] buf) throws IOException {
+        int read = 0;
+        while (read < buf.length) {
+            int n = fileStream.read(buf, read, buf.length - read);
+            if (n < 0) {
+                throw new EOFException(
+                        "Truncated segment header in file "
+                                + files.get(current.fileIndex)
+                                + " at offset "
+                                + current.readOffset
+                                + ": expected "
+                                + buf.length
+                                + " bytes, got "
+                                + read);
+            }
+            read += n;
+            current.advanceReadOffset(n);
+        }
+    }
+
+    /** Reads up to {@code len} body bytes from the current file; called only by the body view. */
+    private int readBody(byte[] buf, int off, int len) throws IOException {
+        int n = fileStream.read(buf, off, len);
+        // Only a positive count moves the live position; -1 (EOF) and 0 leave it untouched.
+        if (n > 0) {
+            current.advanceReadOffset(n);
+        }
+        return n;
+    }
+
+    /** Closes the open file stream, if any, and forgets it even when closing fails. */
+    private void closeFileStream() throws IOException {
+        InputStream toClose = fileStream;
+        // Drop the reference first so a failing close() cannot leave a stale stream behind.
+        fileStream = null;
+        if (toClose != null) {
+            toClose.close();
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Position: the three progress values locating where a snapshot resumes
+    // -------------------------------------------------------------------------------------------
+
+    /**
+     * One progress point, expressed with the three values that fully locate where a snapshot
+     * resumes.
The reader holds two of these: {@code current} (live read position) and {@code + * committed} (delivered boundary). They are the same shape but different in meaning; neither + * keeps a shadow copy of the other. + * + *

+ */ + static final class Position { + private int fileIndex; + private long segmentStartOffset; + private long readOffset; + + Position(int fileIndex, long segmentStartOffset, long readOffset) { + this.fileIndex = fileIndex; + this.segmentStartOffset = segmentStartOffset; + this.readOffset = readOffset; + } + + static Position atStart() { + return new Position(0, 0L, 0L); + } + + Position copy() { + return new Position(fileIndex, segmentStartOffset, readOffset); + } + + /** + * Already-delivered body bytes of the current segment, clamped to 0 before the header is + * crossed (where {@code readOffset <= segmentStartOffset}). Only the first-positioning path + * uses it, to size the prefix a snapshot must discard. + */ + long deliveredBodyBytes() { + return Math.max(0L, readOffset - segmentStartOffset - SEGMENT_HEADER_BYTES); + } + + /** Advances the live read offset by {@code delta} bytes just read/skipped. */ + void advanceReadOffset(long delta) { + readOffset += delta; + } + + /** + * Rewinds the live read offset back to the current segment's header. Used only on first + * positioning: a snapshot's committed offset may sit mid-body, but the header must be read + * first, so the read offset returns to {@code segmentStartOffset} before opening the file. + */ + void rewindToSegmentStart() { + readOffset = segmentStartOffset; + } + + /** + * Marks the current read offset as the start of the segment about to be read (its header + * begins here). Called right before reading a header. + */ + void startSegmentHere() { + segmentStartOffset = readOffset; + } + + /** Rolls to the start of the next file once the current one is exhausted. */ + void rollToNextFile() { + fileIndex++; + segmentStartOffset = 0L; + readOffset = 0L; + } + + /** + * Copies this position into {@code target} but pins {@code target}'s read offset to {@code + * deliveredBody} body bytes of the current segment — i.e. {@code commit} records "delivered + * up to here", not "physically read up to here". 
+ */ + void copyAsDelivered(Position target, long deliveredBody) { + target.fileIndex = fileIndex; + target.segmentStartOffset = segmentStartOffset; + target.readOffset = segmentStartOffset + SEGMENT_HEADER_BYTES + deliveredBody; + } + } + + /** Parsed segment header: channel and full body length. */ + private static final class SegmentHeader { + private final InputChannelInfo channelInfo; + private final int bufferLength; + + private SegmentHeader(InputChannelInfo channelInfo, int bufferLength) { + this.channelInfo = channelInfo; + this.bufferLength = bufferLength; + } + } + + /** + * The single {@link SpillSegment} implementation. Exposes one segment's channel, body, and + * length; {@link #commit()} advances the reader's {@code committed} position to however many + * body bytes have been read. Reading the body and committing are separate steps so the consumer + * can read outside the drainer lock and commit inside it. + * + *

Only the root (drain) reader commits. + */ + private final class Segment implements SpillSegment { + private final InputChannelInfo channelInfo; + private final BoundedSegmentStream body; + + private Segment(InputChannelInfo channelInfo, BoundedSegmentStream body) { + this.channelInfo = channelInfo; + this.body = body; + } + + @Override + public InputChannelInfo channelInfo() { + return channelInfo; + } + + @Override + public InputStream bodyStream() { + return body; + } + + @Override + public int length() { + return body.deliverableLength(); + } + + @Override + public void commit() { + current.copyAsDelivered(committed, body.deliveredFromSegmentHead()); + } + } + + /** + * A forward-only, bounded view over one segment's not-yet-delivered body remainder. It hands + * out {@code remainingLength} bytes and reaches EOF after them; it never reads into the next + * segment or file. The underlying stream is already positioned at the first byte this view + * hands out (the prefix, if any, was skipped before construction). If the file ends before the + * segment end, an {@link EOFException} is thrown (fail-loud). Closing this view does not close + * the underlying file; the reader owns it. + */ + private final class BoundedSegmentStream extends InputStream { + + /** Body bytes already delivered before this view (the skipped prefix); 0 for the root. */ + private final int alreadyDelivered; + + private final int remainingLength; + private int read; + + private BoundedSegmentStream(int remainingLength) { + this(remainingLength, 0); + } + + private BoundedSegmentStream(int remainingLength, int alreadyDelivered) { + this.remainingLength = remainingLength; + this.alreadyDelivered = alreadyDelivered; + } + + /** Body bytes not yet handed out (between the read position and the segment end). */ + int remaining() { + return remainingLength - read; + } + + /** Number of body bytes this view will hand out: the not-yet-delivered remainder. 
*/ + int deliverableLength() { + return remainingLength; + } + + /** + * Total delivered body bytes measured from the segment head: the skipped prefix plus what + * has been read through this view. This is what {@link Segment#commit()} records. + */ + int deliveredFromSegmentHead() { + return alreadyDelivered + read; + } + + @Override + public int read() throws IOException { + byte[] one = new byte[1]; + int n = read(one, 0, 1); + return n < 0 ? -1 : (one[0] & 0xFF); + } + + @Override + public int read(byte[] buf, int off, int len) throws IOException { + if (read >= remainingLength) { + return -1; + } + int toRead = Math.min(len, remainingLength - read); + int n = readBody(buf, off, toRead); + if (n < 0) { + throw new EOFException( + "Unexpected EOF in segment body after " + + read + + "/" + + remainingLength + + " bytes"); + } + read += n; + return n; + } + + @Override + public void close() { + // Do not close the underlying file; it is owned by the reader. + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java new file mode 100644 index 00000000000000..cb54006ece40da --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import java.io.IOException; + +import static org.apache.flink.util.Preconditions.checkState; + +/** + * An immutable resume point for a {@link FetchedChannelState} reader. It captures the {@link + * FetchedChannelStateReaderImpl.Position} at which reading should start and holds one lifecycle + * grant on the underlying {@link FetchedChannelState} (acquired in the constructor). + * + *

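A snapshot is a one-shot, single-reader handle:
+ *
+ * <ul>
+ *   <li>Creating the snapshot acquires one lifecycle grant on the {@link FetchedChannelState}.
+ *   <li>{@link #reader()} may be called at most once to open the reader for this snapshot.
+ *   <li>Closing that reader releases the grant via {@link #release()}.
+ * </ul>
+ *
+ *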

The 1:1 reader constraint is enforced fail-loud: calling {@link #reader()} a second time + * throws {@link IllegalStateException} immediately. + */ +final class FetchedChannelStateSnapshot { + + private final FetchedChannelState channelState; + private final FetchedChannelStateReaderImpl.Position position; + + /** True once {@link #reader()} has been called; prevents opening a second reader. */ + private boolean readerOpened; + + /** + * Creates a snapshot at {@code position} within {@code channelState}. Acquires one lifecycle + * grant on {@code channelState}; the grant is released when the reader returned by {@link + * #reader()} is closed. + */ + FetchedChannelStateSnapshot( + FetchedChannelState channelState, FetchedChannelStateReaderImpl.Position position) { + this.channelState = channelState; + this.position = position; + channelState.acquire(); + } + + /** + * Opens the reader for this snapshot. May be called at most once; a second call fails loud to + * enforce the 1:1 snapshot-to-reader invariant. + * + * @return a new reader starting from this snapshot's position; caller must close it when done + */ + FetchedChannelStateReader reader() { + checkState(!readerOpened, "A reader has already been opened from this snapshot"); + readerOpened = true; + return new FetchedChannelStateReaderImpl(this); + } + + /** + * Releases the lifecycle grant held by this snapshot. Called by the reader on close; must not + * be called directly by any other party. + */ + void release() throws IOException { + channelState.release(); + } + + /** Returns the underlying channel state (package-private; used by the reader). */ + FetchedChannelState channelState() { + return channelState; + } + + /** + * Returns the start position (package-private; used by the reader, which copies it + * immediately). 
+ */ + FetchedChannelStateReaderImpl.Position position() { + return position; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index ca01ff37bd3696..3475169ff82b40 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.core.fs.OffsetAwareOutputStream; +import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; @@ -38,13 +40,21 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.io.BufferedOutputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateByteBuffer.wrap; import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; interface RecoveredChannelStateHandler extends AutoCloseable { @@ -73,28 +83,396 @@ void recover(Info info, int oldSubtaskIndex, BufferWithContext bufferWi throws IOException, InterruptedException; } -class InputChannelRecoveredStateHandler +/** + * Abstract base for all input-channel recovery handlers. 
Holds the channel mapping logic shared by + * all three variants (no-spilling, spilling-no-filtering, spilling-with-filtering). + * + *

Subclasses implement {@link #recover} according to their specific recovery mode and override + * {@link #closeInternal()} to release mode-specific resources. + * + *

Use the static {@link #create} factory to obtain the correct concrete subclass. + */ +abstract class AbstractInputChannelRecoveredStateHandler implements RecoveredChannelStateHandler { - private final InputGate[] inputGates; - private final InflightDataRescalingDescriptor channelMapping; + final InputGate[] inputGates; + final InflightDataRescalingDescriptor channelMapping; + final Map rescaledChannels = new HashMap<>(); + final Map oldToNewMappings = new HashMap<>(); + + AbstractInputChannelRecoveredStateHandler( + InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + this.inputGates = inputGates; + this.channelMapping = channelMapping; + } + + /** + * Factory that selects the correct subclass based on {@code checkpointingDuringRecoveryEnabled} + * and whether a {@code filteringHandler} is present. + * + *

+ */ + static AbstractInputChannelRecoveredStateHandler create( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + boolean checkpointingDuringRecoveryEnabled, + @Nullable ChannelStateFilteringHandler filteringHandler, + int memorySegmentSize, + String[] spillTmpDirectories) { + if (!checkpointingDuringRecoveryEnabled) { + return new NoSpillingHandler(inputGates, channelMapping); + } + if (filteringHandler == null) { + return new SpillingNoFilteringHandler(inputGates, channelMapping, spillTmpDirectories); + } + return new SpillingWithFilteringHandler( + inputGates, + channelMapping, + filteringHandler, + memorySegmentSize, + spillTmpDirectories); + } + + /** Default buffer allocation from the network buffer pool, used by non-filtering modes. */ + @Override + public BufferWithContext getBuffer(InputChannelInfo channelInfo) + throws IOException, InterruptedException { + RecoveredInputChannel channel = getMappedChannels(channelInfo); + Buffer buffer = channel.requestBufferBlocking(); + return new BufferWithContext<>(wrap(buffer), buffer); + } + + /** + * Returns the {@link FetchedChannelState} produced during spilling, or {@code null} if spilling + * was not active (i.e., {@link NoSpillingHandler}). + */ + @Nullable + FetchedChannelState getProducedChannelState() { + return null; + } + + @Override + public void close() throws IOException { + closeInternal(); + } + + /** Hook for subclasses to release their own resources. Called by {@link #close()}. 
*/ + void closeInternal() throws IOException {} + + RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) { + return rescaledChannels.computeIfAbsent(channelInfo, this::calculateMapping); + } + + @Nonnull + private RecoveredInputChannel calculateMapping(InputChannelInfo info) { + final RescaleMappings oldToNewMapping = + oldToNewMappings.computeIfAbsent( + info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert()); + int[] mappedIndexes = oldToNewMapping.getMappedIndexes(info.getInputChannelIdx()); + checkState( + mappedIndexes.length == 1, + "One buffer is only distributed to one target InputChannel since " + + "one buffer is expected to be processed once by the same task."); + return getChannel(info.getGateIdx(), mappedIndexes[0]); + } + + private RecoveredInputChannel getChannel(int gateIndex, int subPartitionIndex) { + final InputChannel inputChannel = inputGates[gateIndex].getChannel(subPartitionIndex); + if (!(inputChannel instanceof RecoveredInputChannel)) { + throw new IllegalStateException( + "Cannot restore state to a non-recovered input channel: " + inputChannel); + } + return (RecoveredInputChannel) inputChannel; + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is disabled. Delivers recovered + * buffers directly into the input channel via {@code onRecoveredStateBuffer}, with no spill file. 
+ */ +class NoSpillingHandler extends AbstractInputChannelRecoveredStateHandler { + + NoSpillingHandler(InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + super(inputGates, channelMapping); + } + + @Override + public void recover( + InputChannelInfo channelInfo, + int oldSubtaskIndex, + BufferWithContext bufferWithContext) + throws IOException, InterruptedException { + Buffer buffer = bufferWithContext.context; + try { + if (buffer.readableBytes() > 0) { + RecoveredInputChannel channel = getMappedChannels(channelInfo); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, channelInfo.getInputChannelIdx()), + false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } + } finally { + buffer.recycleBuffer(); + } + } +} + +/** + * Intermediate abstract base for the two spilling variants. Owns the on-disk spill format end to + * end: a single reusable {@link DataOutputSerializer} accumulates one channel's segment, the + * segment header is backfilled with the body length at seal time, and sealed segments are flushed + * to the current file stream with 64 MB-bounded rotation. + * + *

<h2>Disk format</h2>

+ * + * <pre>
+ * [ 4B BE int: gate idx      ]   segment header: written once per channel segment
+ * [ 4B BE int: channel idx   ]
+ * [ 4B BE int: buffer length ]   segment body byte count (backfilled at segment seal)
+ *   [ 4B BE int: record length ]  repeated for every record in this segment
+ *   [ N bytes: serialized record ]
+ * [ 4B BE int: gate idx      ]   next segment header (channel switch or post-rotation)
+ * ...
+ * </pre>
+ * + *

The body byte count is only known after the whole segment is written, so each segment is first + * accumulated in {@link #segmentSerializer} (header written at open with a zero placeholder) and + * {@link DataOutputSerializer#writeIntUnsafe} backfills the length at seal. A segment is one + * uninterrupted run of records for a single channel; file rotation happens only after a segment is + * fully sealed, so a segment never crosses a file boundary. + */ +abstract class AbstractSpillingHandler extends AbstractInputChannelRecoveredStateHandler { + + /** Byte offset of the {@code bufferLength} field within a segment's header. */ + static final int BUFFER_LENGTH_HEADER_OFFSET = 2 * Integer.BYTES; + + /** Total size of the segment header in bytes: gateIdx + channelIdx + bufferLength. */ + static final int SEGMENT_HEADER_BYTES = 3 * Integer.BYTES; + + final String[] spillTmpDirectories; + + public static final long DEFAULT_SPILL_FILE_SIZE_BYTES = 64L * 1024 * 1024; + + /** Soft per-file size bound that triggers rotation between segments. */ + private final long maxFileSizeBytes; + + /** + * Accumulates the current segment: the header followed by the body, which is either + * length-prefixed filtered records or verbatim pass-through bytes, depending on the subclass. + * Reused across segments via {@code clear()}. + */ + private final DataOutputSerializer segmentSerializer = new DataOutputSerializer(256); + + /** + * Spill files written so far, in order. The {@link FetchedChannelState} handoff is built from + * this list once writing is sealed; an empty list means the handler never spilled any bytes, so + * it produces no state. + */ + private final List files = new ArrayList<>(); + + /** + * Unique directory for this handler's spill files; created lazily when the first file opens. 
+ */ + private final Path baseDir; + + /** + * Output stream to the current spill file; tracks the bytes written so far via {@link + * OffsetAwareOutputStream#getLength()} to decide when to rotate. Null before the first segment + * is flushed. + */ + @Nullable private OffsetAwareOutputStream currentStream; + + /** Channel whose segment is currently open; null when no segment is in progress. */ + @Nullable private InputChannelInfo currentChannel; + + @Nullable private FetchedChannelState producedChannelState; + + AbstractSpillingHandler( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + String[] spillTmpDirectories, + long maxFileSizeBytes) { + super(inputGates, channelMapping); + checkArgument( + checkNotNull(spillTmpDirectories).length > 0, + "spillTmpDirectories must not be empty"); + checkArgument( + maxFileSizeBytes > 0, "maxFileSizeBytes must be positive: %s", maxFileSizeBytes); + this.spillTmpDirectories = spillTmpDirectories; + this.maxFileSizeBytes = maxFileSizeBytes; + this.baseDir = + Paths.get(spillTmpDirectories[0], "flink-channel-spill-" + UUID.randomUUID()); + } + + /** + * Opens (or switches to) the segment for {@code channelInfo} and returns its buffer for the + * caller to append the body into. The caller must not seal the segment. 
+ */ + DataOutputSerializer segmentSerializerFor(InputChannelInfo channelInfo) throws IOException { + switchChannelIfNeeded(channelInfo); + return segmentSerializer; + } + + private void switchChannelIfNeeded(InputChannelInfo channelInfo) throws IOException { + if (channelInfo.equals(currentChannel)) { + return; + } + if (currentChannel != null) { + sealCurrentSegment(); + } + segmentSerializer.clear(); + segmentSerializer.writeInt(channelInfo.getGateIdx()); + segmentSerializer.writeInt(channelInfo.getInputChannelIdx()); + segmentSerializer.writeInt(0); // bufferLength placeholder + currentChannel = channelInfo; + } + + /** + * Backfills the body length into the segment header and flushes the whole segment to the file + * stream. Empty segments (filtered out entirely, or a zero-byte pass-through) are dropped + * without opening a file, so no empty file is created. + */ + private void sealCurrentSegment() throws IOException { + if (currentChannel == null) { + return; + } + currentChannel = null; + int totalBytes = segmentSerializer.length(); + int bodyBytes = totalBytes - SEGMENT_HEADER_BYTES; + if (bodyBytes == 0) { + return; + } + // NOTE: bodyBytes is already an int, so Math.toIntExact never throws here; it does not + // guard against a single segment growing past 2 GB in segmentSerializer. + segmentSerializer.writeIntUnsafe(Math.toIntExact(bodyBytes), BUFFER_LENGTH_HEADER_OFFSET); + ensureFileOpen(); + currentStream.write(segmentSerializer.getSharedBuffer(), 0, totalBytes); + } + + /** + * Ensures an output stream is ready for the next segment, rotating to a fresh file first if the + * current one reached the size bound. Rotation happens here, between sealed segments, so a + * segment is never split across files. 
+ */ + private void ensureFileOpen() throws IOException { + if (currentStream != null && currentStream.getLength() >= maxFileSizeBytes) { + currentStream.flush(); + currentStream.close(); + currentStream = null; + } + if (currentStream != null) { + return; + } + // create the spill dir on the first file; no-op afterwards + Files.createDirectories(baseDir); + Path filePath = baseDir.resolve("spill-segment-" + files.size() + ".bin"); + currentStream = + new OffsetAwareOutputStream( + new BufferedOutputStream(new FileOutputStream(filePath.toFile())), 0L); + files.add(filePath); + } + + @Override + @Nullable + FetchedChannelState getProducedChannelState() { + return producedChannelState; + } - private final Map rescaledChannels = new HashMap<>(); - private final Map oldToNewMappings = new HashMap<>(); + /** Spill files written so far; empty if this handler never spilled any bytes. */ + @VisibleForTesting + List peekSpillFilesForTesting() { + return files; + } /** - * Optional filtering handler for filtering recovered buffers. When non-null, filtering is - * performed during recovery in the channel-state-unspilling thread. + * Seals the open segment and the file stream, then builds the {@link FetchedChannelState} + * handoff from the written files. Produces nothing if no bytes were ever spilled. */ - @Nullable private final ChannelStateFilteringHandler filteringHandler; + @Override + void closeInternal() throws IOException { + if (currentChannel != null) { + sealCurrentSegment(); + } + if (currentStream != null) { + currentStream.flush(); + currentStream.close(); // OffsetAwareOutputStream closes the wrapped stream quietly + currentStream = null; + } + if (files.isEmpty()) { + return; + } + producedChannelState = new FetchedChannelState(files); + // Keep the files alive between close() and drain-reader construction. 
+ producedChannelState.acquire(); + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is enabled but no filtering + * handler is present. Appends recovered buffer bytes verbatim into the current segment. + */ +class SpillingNoFilteringHandler extends AbstractSpillingHandler { + + SpillingNoFilteringHandler( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + String[] spillTmpDirectories) { + super(inputGates, channelMapping, spillTmpDirectories, DEFAULT_SPILL_FILE_SIZE_BYTES); + } + + @Override + public void recover( + InputChannelInfo channelInfo, + int oldSubtaskIndex, + BufferWithContext bufferWithContext) + throws IOException, InterruptedException { + Buffer buffer = bufferWithContext.context; + try { + if (buffer.readableBytes() > 0) { + recoverPassThroughToSpill(getMappedChannels(channelInfo).getChannelInfo(), buffer); + } + } finally { + buffer.recycleBuffer(); + } + } + + private void recoverPassThroughToSpill(InputChannelInfo channelInfo, Buffer source) + throws IOException { + // The recovered bytes are already a length-prefixed record sequence, so append them + // verbatim into the segment without re-framing. Writing straight from the backing + // MemorySegment lets it absorb the heap/off-heap distinction, avoiding both a branch on the + // NIO buffer kind and the intermediate copy a direct buffer would otherwise require. + segmentSerializerFor(channelInfo) + .write( + source.getMemorySegment(), + source.getMemorySegmentOffset() + source.getReaderIndex(), + source.readableBytes()); + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is enabled and a filtering + * handler is present. Uses a reusable heap-backed pre-filter buffer (isolated from the Network + * Buffer Pool) and writes filtered/rewritten output to the spill file via {@link + * ChannelStateFilteringHandler#filterAndRewrite}. 
+ */ +class SpillingWithFilteringHandler extends AbstractSpillingHandler { + + private final ChannelStateFilteringHandler filteringHandler; /** Network buffer memory segment size in bytes. Used to size the reusable pre-filter buffer. */ private final int memorySegmentSize; /** * Reusable heap memory segment backing the pre-filter buffer in filtering mode. Lazily - * allocated on the first {@link #getPreFilterBuffer} call, reused for every subsequent call, - * and freed in {@link #close()}. + * allocated on the first {@link #getBuffer} call, reused for every subsequent call, and freed + * in {@link #closeInternal()}. * *

Reuse is safe because at most one pre-filter buffer is in flight per task at any moment. * This invariant is enforced at runtime by {@link #preFilterBufferInUse}. @@ -108,39 +486,27 @@ class InputChannelRecoveredStateHandler */ private boolean preFilterBufferInUse; - InputChannelRecoveredStateHandler( + SpillingWithFilteringHandler( InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping, - @Nullable ChannelStateFilteringHandler filteringHandler, - int memorySegmentSize) { - this.inputGates = inputGates; - this.channelMapping = channelMapping; + ChannelStateFilteringHandler filteringHandler, + int memorySegmentSize, + String[] spillTmpDirectories) { + super(inputGates, channelMapping, spillTmpDirectories, DEFAULT_SPILL_FILE_SIZE_BYTES); this.filteringHandler = filteringHandler; checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; } - @Override - public BufferWithContext getBuffer(InputChannelInfo channelInfo) - throws IOException, InterruptedException { - if (filteringHandler != null) { - return getPreFilterBuffer(); - } - // Non-filtering mode: use existing network buffer pool allocation. - RecoveredInputChannel channel = getMappedChannels(channelInfo); - Buffer buffer = channel.requestBufferBlocking(); - return new BufferWithContext<>(wrap(buffer), buffer); - } - /** * Allocates a pre-filter buffer from a reusable heap segment (isolated from the Network Buffer * Pool) in filtering mode. * *

Memory management: a single {@link MemorySegment} per task is lazily allocated on first * invocation and reused across every subsequent call. The custom {@link BufferRecycler} does - * not free the segment — it only flips {@link #preFilterBufferInUse} back to {@code false} so - * the next call can reuse it. The segment itself is freed in {@link #close()}. + * not free the segment; it only flips {@link #preFilterBufferInUse} back to {@code false} so + * the next call can reuse it. The segment itself is freed in {@link #closeInternal()}. * *

Runtime invariant check: the one-at-a-time invariant on pre-filter buffers is guaranteed * by Flink's serial recovery loop and the deserializer's ownership contract. This method @@ -148,7 +514,8 @@ public BufferWithContext getBuffer(InputChannelInfo channelInfo) * recycled, it throws {@link IllegalStateException} so any future regression fails loudly * instead of silently corrupting memory. */ - private BufferWithContext getPreFilterBuffer() { + @Override + public BufferWithContext getBuffer(InputChannelInfo channelInfo) { checkState( !preFilterBufferInUse, "Previous pre-filter buffer has not been recycled. This violates the " @@ -165,17 +532,6 @@ private BufferWithContext getPreFilterBuffer() { return new BufferWithContext<>(wrap(buffer), buffer); } - @VisibleForTesting - boolean isPreFilterBufferInUse() { - return preFilterBufferInUse; - } - - @VisibleForTesting - @Nullable - MemorySegment getPreFilterSegmentForTesting() { - return preFilterSegment; - } - @Override public void recover( InputChannelInfo channelInfo, @@ -185,90 +541,40 @@ public void recover( Buffer buffer = bufferWithContext.context; try { if (buffer.readableBytes() > 0) { - RecoveredInputChannel channel = getMappedChannels(channelInfo); - - if (filteringHandler != null) { - recoverWithFiltering( - channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer()); - } else { - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); - } - } - } finally { - buffer.recycleBuffer(); - } - } - - private void recoverWithFiltering( - RecoveredInputChannel channel, - InputChannelInfo channelInfo, - int oldSubtaskIndex, - Buffer retainedBuffer) - throws IOException, InterruptedException { - checkState(filteringHandler != null, "filtering handler not set."); - List filteredBuffers = filteringHandler.filterAndRewrite( channelInfo.getGateIdx(), 
oldSubtaskIndex, channelInfo.getInputChannelIdx(), - retainedBuffer, - channel::requestBufferBlocking); - - int i = 0; - try { - for (; i < filteredBuffers.size(); i++) { - channel.onRecoveredStateBuffer(filteredBuffers.get(i)); + buffer.retainBuffer(), + segmentSerializerFor(getMappedChannels(channelInfo).getChannelInfo())); } - } catch (Throwable t) { - for (int j = i; j < filteredBuffers.size(); j++) { - filteredBuffers.get(j).recycleBuffer(); - } - throw t; - } - } - - @Override - public void close() throws IOException { - // note that we need to finish all RecoveredInputChannels, not just those with state - for (final InputGate inputGate : inputGates) { - inputGate.finishReadRecoveredState(); - } - if (preFilterSegment != null) { - preFilterSegment.free(); - preFilterSegment = null; - preFilterBufferInUse = false; + } finally { + buffer.recycleBuffer(); } } - private RecoveredInputChannel getChannel(int gateIndex, int subPartitionIndex) { - final InputChannel inputChannel = inputGates[gateIndex].getChannel(subPartitionIndex); - if (!(inputChannel instanceof RecoveredInputChannel)) { - throw new IllegalStateException( - "Cannot restore state to a non-recovered input channel: " + inputChannel); - } - return (RecoveredInputChannel) inputChannel; + @VisibleForTesting + boolean isPreFilterBufferInUse() { + return preFilterBufferInUse; } - private RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) { - return rescaledChannels.computeIfAbsent(channelInfo, this::calculateMapping); + @VisibleForTesting + @Nullable + MemorySegment getPreFilterSegmentForTesting() { + return preFilterSegment; } - @Nonnull - private RecoveredInputChannel calculateMapping(InputChannelInfo info) { - final RescaleMappings oldToNewMapping = - oldToNewMappings.computeIfAbsent( - info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert()); - int[] mappedIndexes = oldToNewMapping.getMappedIndexes(info.getInputChannelIdx()); - checkState( - mappedIndexes.length == 
1, - "One buffer is only distributed to one target InputChannel since " - + "one buffer is expected to be processed once by the same task."); - return getChannel(info.getGateIdx(), mappedIndexes[0]); + @Override + void closeInternal() throws IOException { + try { + super.closeInternal(); + } finally { + if (preFilterSegment != null) { + preFilterSegment.free(); + preFilterSegment = null; + preFilterBufferInUse = false; + } + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java new file mode 100644 index 00000000000000..d263168b0119cd --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.RuntimeEvent; + +/** Task-local event marking the recovery-state cut for a recovery checkpoint. */ +// review nit: checkpoint during recovery +@Internal +public final class RecoveryCheckpointBarrier extends RuntimeEvent { + + private final long checkpointId; + + public RecoveryCheckpointBarrier(long checkpointId) { + this.checkpointId = checkpointId; + } + + public long getCheckpointId() { + return checkpointId; + } + + @Override + public void write(DataOutputView out) { + throw new UnsupportedOperationException( + "RecoveryCheckpointBarrier must be serialized via EventSerializer's dedicated" + + " type-tag path, not reflective write()."); + } + + @Override + public void read(DataInputView in) { + throw new UnsupportedOperationException( + "RecoveryCheckpointBarrier must be deserialized via EventSerializer's dedicated" + + " type-tag path, not reflective read()."); + } + + @Override + public int hashCode() { + return Long.hashCode(checkpointId); + } + + @Override + public boolean equals(Object other) { + return other != null + && other.getClass() == RecoveryCheckpointBarrier.class + && ((RecoveryCheckpointBarrier) other).checkpointId == this.checkpointId; + } + + @Override + public String toString() { + return "RecoveryCheckpointBarrier(" + checkpointId + ")"; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java new file mode 100644 index 00000000000000..d2f08736f856ad --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +import java.io.IOException; + +@Internal +public interface RecoveryCheckpointTrigger { + + /** + * Atomically snapshots the undrained spill slice and inserts matching {@link + * RecoveryCheckpointBarrier}s into in-recovery channels. Returns an independent reader over the + * remaining segments; the caller owns and must close it. + */ + FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) throws IOException; + + /** Returns an empty reader (no spill files, so no segments) and inserts no barriers. 
*/ + RecoveryCheckpointTrigger NO_OP = checkpointId -> FetchedChannelStateReader.emptyReader(); + + RecoveryCheckpointTrigger NOT_READY = + ign -> { + throw new IllegalStateException("RecoveryCheckpointTrigger is not ready yet"); + }; + RecoveryCheckpointTrigger FAILING = + ign -> { + throw new IllegalStateException("Triggering checkpoints is not possible"); + }; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java index 547b60ef93aee7..88296c517b0150 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java @@ -23,6 +23,7 @@ import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import java.io.IOException; +import java.util.Optional; /** Reads channel state saved during checkpoint/savepoint. */ @Internal @@ -34,7 +35,8 @@ public interface SequentialChannelStateReader extends AutoCloseable { * @param inputGates The input gates to recover state for. * @param filterContext The filter context containing input configs and rescaling info. 
*/ - void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) throws IOException, InterruptedException; void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) @@ -47,8 +49,10 @@ void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCom new SequentialChannelStateReader() { @Override - public void readInputData( - InputGate[] inputGates, RecordFilterContext filterContext) {} + public Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) { + return Optional.empty(); + } @Override public void readOutputData( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index c52572e52faecb..263335bc0c7c8e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -30,12 +30,15 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import javax.annotation.Nullable; + import java.io.Closeable; import java.io.IOException; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; @@ -53,6 +56,8 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR private final ChannelStateSerializer serializer; private final ChannelStateChunkReader chunkReader; + @Nullable private FetchedChannelState producedChannelState; + public 
SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { this.taskStateSnapshot = taskStateSnapshot; serializer = new ChannelStateSerializerImpl(); @@ -60,7 +65,8 @@ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { } @Override - public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + public Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) throws IOException, InterruptedException { // Create filtering handler if filtering is needed @@ -69,29 +75,38 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont ? ChannelStateFilteringHandler.createFromContext(filterContext, inputGates) : null; - try (ChannelStateFilteringHandler ignored = filteringHandler; - InputChannelRecoveredStateHandler stateHandler = - new InputChannelRecoveredStateHandler( - inputGates, - taskStateSnapshot.getInputRescalingDescriptor(), - filteringHandler, - filterContext.getMemorySegmentSize())) { - read( - stateHandler, - groupByDelegate( - streamSubtaskStates(), - ChannelStateHelper::extractUnmergedInputHandles)); - read( - stateHandler, - groupByDelegate( - streamSubtaskStates(), - OperatorSubtaskState::getUpstreamOutputBufferState)); + // Manual close ordering so the produced spill file can be published after + // stateHandler.close() flushes the filter writer. 
+ AbstractInputChannelRecoveredStateHandler stateHandler = + AbstractInputChannelRecoveredStateHandler.create( + inputGates, + taskStateSnapshot.getInputRescalingDescriptor(), + filterContext.isCheckpointingDuringRecoveryEnabled(), + filteringHandler, + filterContext.getMemorySegmentSize(), + filterContext.getTmpDirectories()); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + try (stateHandler) { + read( + stateHandler, + groupByDelegate( + streamSubtaskStates(), + ChannelStateHelper::extractUnmergedInputHandles)); + read( + stateHandler, + groupByDelegate( + streamSubtaskStates(), + OperatorSubtaskState::getUpstreamOutputBufferState)); - if (filteringHandler != null) { - checkState( - !filteringHandler.hasPartialData(), - "Not all data has been fully consumed during filtering"); + if (filteringHandler != null) { + checkState( + !filteringHandler.hasPartialData(), + "Not all data has been fully consumed during filtering"); + } } + // stateHandler.close() (above) has flushed the filter writer and published the + // produced spill file, so read getProducedChannelState() after the close completes. 
+ return Optional.ofNullable(stateHandler.getProducedChannelState()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java index e73d9168cb273f..5711d1640f9094 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.SavepointType; import org.apache.flink.runtime.checkpoint.SnapshotType; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.WatermarkEvent; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; @@ -43,6 +44,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferConsumer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.EndOfFetchedChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfInputChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfOutputChannelStateEvent; import org.apache.flink.runtime.state.CheckpointStorageLocationReference; @@ -87,6 +89,10 @@ public class EventSerializer { private static final int END_OF_INPUT_CHANNEL_STATE_EVENT = 12; + private static final int RECOVERY_CHECKPOINT_BARRIER_EVENT = 13; + + private static final int END_OF_FETCHED_CHANNEL_STATE_EVENT = 14; + private static final byte CHECKPOINT_TYPE_CHECKPOINT = 0; private static final byte CHECKPOINT_TYPE_SAVEPOINT = 1; @@ -116,6 +122,8 @@ public static ByteBuffer 
toSerializedEvent(AbstractEvent event) throws IOExcepti return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_OUTPUT_CHANNEL_STATE_EVENT}); } else if (eventClass == EndOfInputChannelStateEvent.class) { return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_INPUT_CHANNEL_STATE_EVENT}); + } else if (eventClass == EndOfFetchedChannelStateEvent.class) { + return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_FETCHED_CHANNEL_STATE_EVENT}); } else if (eventClass == EndOfData.class) { return ByteBuffer.wrap( new byte[] { @@ -162,6 +170,13 @@ public static ByteBuffer toSerializedEvent(AbstractEvent event) throws IOExcepti buf.putInt(0, RECOVERY_METADATA); buf.putInt(4, recoveryMetadata.getFinalBufferSubpartitionId()); return buf; + } else if (eventClass == RecoveryCheckpointBarrier.class) { + RecoveryCheckpointBarrier barrier = (RecoveryCheckpointBarrier) event; + + ByteBuffer buf = ByteBuffer.allocate(12); + buf.putInt(0, RECOVERY_CHECKPOINT_BARRIER_EVENT); + buf.putLong(4, barrier.getCheckpointId()); + return buf; } else if (eventClass == WatermarkEvent.class) { try { final DataOutputSerializer serializer = new DataOutputSerializer(128); @@ -206,6 +221,8 @@ public static AbstractEvent fromSerializedEvent(ByteBuffer buffer, ClassLoader c return EndOfOutputChannelStateEvent.INSTANCE; } else if (type == END_OF_INPUT_CHANNEL_STATE_EVENT) { return EndOfInputChannelStateEvent.INSTANCE; + } else if (type == END_OF_FETCHED_CHANNEL_STATE_EVENT) { + return EndOfFetchedChannelStateEvent.INSTANCE; } else if (type == END_OF_USER_RECORDS_EVENT) { return new EndOfData(StopMode.values()[buffer.get()]); } else if (type == CANCEL_CHECKPOINT_MARKER_EVENT) { @@ -222,6 +239,9 @@ public static AbstractEvent fromSerializedEvent(ByteBuffer buffer, ClassLoader c } else if (type == RECOVERY_METADATA) { int subpartitionId = buffer.getInt(); return new RecoveryMetadata(subpartitionId); + } else if (type == RECOVERY_CHECKPOINT_BARRIER_EVENT) { + long checkpointId = buffer.getLong(); + return new 
RecoveryCheckpointBarrier(checkpointId); } else if (type == GENERALIZED_WATERMARK_EVENT) { final DataInputDeserializer deserializer = new DataInputDeserializer(buffer); WatermarkEvent watermarkEvent = new WatermarkEvent(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java index 204f3d3c074b6c..137334ae14fc35 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java @@ -98,13 +98,13 @@ public static void traceRecover( public static void tracePersist( String action, Buffer buffer, Object channelInfo, long checkpointId) { + tracePersist(action, buffer.toDebugString(INCLUDE_HASH), channelInfo, checkpointId); + } + + public static void tracePersist( + String action, Object persisted, Object channelInfo, long checkpointId) { if (LOG.isTraceEnabled()) { - LOG.trace( - "{} {}, checkpoint {} @ {}", - action, - buffer.toDebugString(INCLUDE_HASH), - checkpointId, - channelInfo); + LOG.trace("{} {}, checkpoint {} @ {}", action, persisted, checkpointId, channelInfo); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java index ae57b39b08d3c5..9498353f4d09c9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java @@ -47,7 +47,8 @@ *

It also keeps track of available buffers and notifies the outbound handler about * non-emptiness, similar to the {@link LocalInputChannel}. */ -class CreditBasedSequenceNumberingViewReader +class CreditBasedSequenceNumberingViewReader // review nit: can we split THIS commit into remote and + // local input channel changes? implements BufferAvailabilityListener, NetworkSequenceViewReader { private final Object requestLock = new Object(); @@ -81,12 +82,17 @@ class CreditBasedSequenceNumberingViewReader private int numCreditsAvailable; CreditBasedSequenceNumberingViewReader( - InputChannelID receiverId, int initialCredit, PartitionRequestQueue requestQueue) { + InputChannelID receiverId, + int initialCredit, + boolean needsRecovery, + PartitionRequestQueue requestQueue) { checkArgument(initialCredit >= 0, "Must be non-negative."); this.receiverId = receiverId; this.initialCredit = initialCredit; - this.numCreditsAvailable = initialCredit; + // During spill recovery, exclusive buffers are on loan to the recovery drain; real credit + // is announced only after recovery completes. + this.numCreditsAvailable = needsRecovery ? 
0 : initialCredit; this.requestQueue = requestQueue; this.subpartitionId = -1; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index ebe596e01f6b5a..6747e139a034c3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -583,15 +583,19 @@ static class PartitionRequest extends NettyMessage { final int credit; + final boolean needsRecovery; + PartitionRequest( ResultPartitionID partitionId, ResultSubpartitionIndexSet queueIndexSet, InputChannelID receiverId, - int credit) { + int credit, + boolean needsRecovery) { this.partitionId = checkNotNull(partitionId); this.queueIndexSet = queueIndexSet; this.receiverId = checkNotNull(receiverId); this.credit = credit; + this.needsRecovery = needsRecovery; } @Override @@ -604,6 +608,7 @@ void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator queueIndexSet.writeTo(bb); receiverId.writeTo(bb); bb.writeInt(credit); + bb.writeBoolean(needsRecovery); }; writeToChannel( @@ -616,7 +621,8 @@ void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator + ExecutionAttemptID.getByteBufLength() + ResultSubpartitionIndexSet.getByteBufLength(queueIndexSet) + InputChannelID.getByteBufLength() - + Integer.BYTES); + + Integer.BYTES + + Byte.BYTES); } static PartitionRequest readFrom(ByteBuf buffer) { @@ -628,8 +634,10 @@ static PartitionRequest readFrom(ByteBuf buffer) { ResultSubpartitionIndexSet.fromByteBuf(buffer); InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); int credit = buffer.readInt(); + boolean needsRecovery = buffer.readBoolean(); - return new PartitionRequest(partitionId, queueIndexSet, receiverId, credit); + return new PartitionRequest( + partitionId, queueIndexSet, receiverId, credit, 
needsRecovery); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java index 7cc52a234e09a9..737e3de72690df 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java @@ -129,7 +129,8 @@ public void requestSubpartition( partitionId, subpartitionIndexSet, inputChannel.getInputChannelId(), - inputChannel.getInitialCredit()); + inputChannel.getInitialCredit(), + inputChannel.needsRecovery()); final ChannelFutureListener listener = future -> { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java index f93651b9f3e8d2..6f2cd6b13ad419 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java @@ -85,7 +85,10 @@ protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws NetworkSequenceViewReader reader; reader = new CreditBasedSequenceNumberingViewReader( - request.receiverId, request.credit, outboundQueue); + request.receiverId, + request.credit, + request.needsRecovery, + outboundQueue); reader.requestSubpartitionViewOrRegisterListener( partitionProvider, request.partitionId, request.queueIndexSet); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java index db38025def9e3d..1eed2a0284fd0a 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java @@ -70,13 +70,24 @@ public class BufferManager implements BufferListener, BufferRecycler { @GuardedBy("bufferQueue") private int numRequiredBuffers; + /** + * Gates credit announcements while a recovery drain borrows this channel's buffers. Kept under + * {@code bufferQueue} to avoid inverting the queue/recovered-buffer lock order. + */ + @GuardedBy("bufferQueue") + private boolean notifyAvailable; + public BufferManager( - MemorySegmentProvider globalPool, InputChannel inputChannel, int numRequiredBuffers) { + MemorySegmentProvider globalPool, + InputChannel inputChannel, + int numRequiredBuffers, + boolean notifyInitiallyEnabled) { this.globalPool = checkNotNull(globalPool); this.inputChannel = checkNotNull(inputChannel); checkArgument(numRequiredBuffers >= 0); this.numRequiredBuffers = numRequiredBuffers; + this.notifyAvailable = notifyInitiallyEnabled; } // ------------------------------------------------------------------------ @@ -158,23 +169,22 @@ void requestExclusiveBuffers(int numExclusiveBuffers) throws IOException { /** * Requests floating buffers from the buffer pool based on the given required amount, and - * returns the actual requested amount. If the required amount is not fully satisfied, it will - * register as a listener. + * returns the number of buffers that may be announced to the producer as credit. During + * recovery, requested buffers are queued but announced only by {@link #enableNotify()}. */ int requestFloatingBuffers(int numRequired) { - int numRequestedBuffers = 0; synchronized (bufferQueue) { // Similar to notifyBufferAvailable(), make sure that we never add a buffer after // channel // released all buffers via releaseAllResources(). 
if (inputChannel.isReleased()) { - return numRequestedBuffers; + return 0; } numRequiredBuffers = numRequired; - numRequestedBuffers = tryRequestBuffers(); + int numRequestedBuffers = tryRequestBuffers(); + return notifyAvailable ? numRequestedBuffers : 0; } - return numRequestedBuffers; } private int tryRequestBuffers() { @@ -209,6 +219,7 @@ private int tryRequestBuffers() { @Override public void recycle(MemorySegment segment) { @Nullable Buffer releasedFloatingBuffer = null; + boolean announceCredit = false; synchronized (bufferQueue) { try { // Similar to notifyBufferAvailable(), make sure that we never add a buffer @@ -226,11 +237,12 @@ public void recycle(MemorySegment segment) { } finally { bufferQueue.notifyAll(); } + announceCredit = releasedFloatingBuffer == null && notifyAvailable; } if (releasedFloatingBuffer != null) { releasedFloatingBuffer.recycleBuffer(); - } else { + } else if (announceCredit) { try { inputChannel.notifyBufferAvailable(1); } catch (Throwable t) { @@ -344,6 +356,9 @@ public boolean notifyBufferAvailable(Buffer buffer) { isBufferUsed = true; numBuffers += 1 + tryRequestBuffers(); bufferQueue.notifyAll(); + if (!notifyAvailable) { + numBuffers = 0; + } } inputChannel.notifyBufferAvailable(numBuffers); @@ -359,6 +374,19 @@ public void notifyBufferDestroyed() { // Nothing to do actually. } + /** + * Opens the recovery credit gate and announces the queued buffers atomically with respect to + * concurrent recycle/floating-buffer callbacks. 
+ */ + void enableNotify() throws IOException { + int available; + synchronized (bufferQueue) { + notifyAvailable = true; + available = bufferQueue.getAvailableBufferSize(); + } + inputChannel.notifyBufferAvailable(available); + } + // ------------------------------------------------------------------------ // Getter properties // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java new file mode 100644 index 00000000000000..9900a0cdc660df --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.RuntimeEvent; + +/** + * Marks the tail of recovered buffers that the spill drain pushed into a {@link + * RecoverableInputChannel}. The consume path polls this sentinel to learn the exact moment all + * recovered buffers have been consumed; it is never delivered to the operator. It is distinct from + * {@link EndOfInputChannelStateEvent} (which terminates the {@link RecoveredInputChannel} read + * stream) so the two recovery handoffs cannot be confused. + */ +public class EndOfFetchedChannelStateEvent extends RuntimeEvent { + + /** The singleton instance of this event. */ + public static final EndOfFetchedChannelStateEvent INSTANCE = + new EndOfFetchedChannelStateEvent(); + + // ------------------------------------------------------------------------ + + // not instantiable + private EndOfFetchedChannelStateEvent() {} + + // ------------------------------------------------------------------------ + + @Override + public void write(DataOutputView out) { + throw new UnsupportedOperationException( + "EndOfFetchedChannelStateEvent must be serialized via EventSerializer's dedicated" + + " type-tag path, not reflective write()."); + } + + @Override + public void read(DataInputView in) { + throw new UnsupportedOperationException( + "EndOfFetchedChannelStateEvent must be deserialized via EventSerializer's dedicated" + + " type-tag path, not reflective read()."); + } + + // ------------------------------------------------------------------------ + + @Override + public int hashCode() { + return 20250814; + } + + @Override + public boolean equals(Object obj) { + return obj != null && obj.getClass() == EndOfFetchedChannelStateEvent.class; + } + + @Override + public String toString() { + return getClass().getSimpleName(); + } +} diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java index 915012924d1610..5daa277cd9b34f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java @@ -74,10 +74,4 @@ public void convertToPriorityEvent(int channelIndex, int sequenceNumber) throws public abstract ResultPartitionType getConsumedPartitionType(); public abstract void triggerDebloating(); - - /** Sets whether unaligned checkpointing during recovery is enabled. */ - public abstract void setCheckpointingDuringRecoveryEnabled(boolean enabled); - - /** Returns whether unaligned checkpointing during recovery is enabled. */ - public abstract boolean isCheckpointingDuringRecoveryEnabled(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index dd744bae330ff3..142f45cf9bcc5e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -141,6 +141,14 @@ public abstract void acknowledgeAllRecordsProcessed(InputChannelInfo channelInfo /** Returns the channel of this gate. */ public abstract InputChannel getChannel(int channelIndex); + /** + * Returns the channel identified by {@code channelInfo}. Unlike {@link #getChannel(int)}, whose + * index is gate-global, this resolves through the full {@code (gateIdx, inputChannelIdx)} pair, + * so it stays correct for {@link UnionInputGate} where the global index differs from a member + * gate's local channel index. 
+ */ + public abstract InputChannel getChannel(InputChannelInfo channelInfo); + /** Returns the channel infos of this gate. */ public List getChannelInfos() { return IntStream.range(0, getNumberOfInputChannels()) @@ -190,14 +198,16 @@ public String toString() { public abstract void requestPartitions() throws IOException; - public abstract CompletableFuture getStateConsumedFuture(); - /** - * Returns a future that completes when buffer filtering is complete for all channels. This - * future completes before {@link #getStateConsumedFuture()}, enabling earlier RUNNING state - * transition when unaligned checkpoint during recovery is enabled. + * Requests the partitions. {@code needsRecovery} controls whether converted physical channels + * start in recovery (i.e. with no credit / no floating buffers until the recovered channel + * state has been drained). The default implementation ignores the flag. */ - public abstract CompletableFuture getBufferFilteringCompleteFuture(); + public void requestPartitions(boolean needsRecovery) throws IOException { + requestPartitions(); + } + + public abstract CompletableFuture getStateConsumedFuture(); public abstract void finishReadRecoveredState() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 661e4b063c75f0..b4a7e2bd20b1a6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -21,11 +21,15 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import 
org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; +import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.TaskEventPublisher; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.CompositeBuffer; import org.apache.flink.runtime.io.network.buffer.FileRegionBuffer; @@ -43,21 +47,26 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collections; import java.util.Deque; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Timer; import java.util.TimerTask; +import java.util.concurrent.CompletableFuture; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** An input channel, which requests a local subpartition. */ -public class LocalInputChannel extends InputChannel implements BufferAvailabilityListener { +public class LocalInputChannel extends InputChannel + implements BufferAvailabilityListener, RecoverableInputChannel { private static final Logger LOG = LoggerFactory.getLogger(LocalInputChannel.class); @@ -81,12 +90,46 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit private final Deque toBeConsumedBuffers = new ArrayDeque<>(); /** - * Flag indicating whether there is a pending priority event (e.g., checkpoint barrier) in the - * subpartitionView that should be consumed before toBeConsumedBuffers. 
This is set by {@link - * #notifyPriorityEvent} and checked in {@link #getNextBuffer()}. + * Buffers delivered from {@code RecoveredInputChannel}, kept separately from {@link + * #toBeConsumedBuffers} so that recovery semantics (priority event interleaving, checkpoint + * inflight persistence) do not leak into the FullyFilledBuffer split path. Holds recovered + * buffers plus the {@code RecoveryCheckpointBarrier} and {@code EndOfFetchedChannelStateEvent} + * sentinels. The deque object is its own monitor; {@link #inRecovery} and {@link + * #recoverySequenceNumber} are guarded by it too. + */ + private final Deque recoveredBuffers = new ArrayDeque<>(); + + /** + * Whether the channel is still replaying recovered state. Starts {@code true} for channels + * that need recovery and is flipped to {@code false} the moment the consume path polls + * the {@code EndOfFetchedChannelStateEvent} sentinel appended after the last recovered buffer + * (see {@link #onRecoveredStateConsumed()}). While {@code true} the consume path serves + * recovered buffers and does not poll ordinary upstream data. + */ + @GuardedBy("recoveredBuffers") + private boolean inRecovery; + + /** + * Sequence number assigned to recovered buffers, starting at {@link Integer#MIN_VALUE}, + * consistent with {@link RecoveredInputChannel}. + */ + private int recoverySequenceNumber = Integer.MIN_VALUE; + + @Nullable private final BufferManager bufferManager; + + private final int networkBuffersPerChannel; + + private final boolean needsRecovery; + + /** + * Whether a priority event (e.g., checkpoint barrier) is pending in {@code subpartitionView} + * and must be consumed before {@code recoveredBuffers}. Volatile because it is written by the + * network thread and read by the task thread. 
*/ private volatile boolean hasPendingPriorityEvent = false; + private final CompletableFuture upstreamReady = new CompletableFuture<>(); + public LocalInputChannel( SingleInputGate inputGate, int channelIndex, @@ -99,7 +142,8 @@ public LocalInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + int networkBuffersPerChannel, + boolean needsRecovery) { super( inputGate, @@ -113,48 +157,220 @@ public LocalInputChannel( this.partitionManager = checkNotNull(partitionManager); this.taskEventPublisher = checkNotNull(taskEventPublisher); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - while (!initialRecoveredBuffers.isEmpty()) { - Buffer buffer = initialRecoveredBuffers.poll(); - // Determine next data type based on the next buffer in the queue - Buffer.DataType nextDataType = - initialRecoveredBuffers.isEmpty() - ? Buffer.DataType.NONE - : initialRecoveredBuffers.peek().getDataType(); - // buffersInBacklog is set to 0 as these are recovered buffers - BufferAndBacklog bufferAndBacklog = - new BufferAndBacklog(buffer, 0, nextDataType, seqNum++); - toBeConsumedBuffers.add(bufferAndBacklog); + this.channelStatePersister = + new ChannelStatePersister(checkNotNull(stateWriter), getChannelInfo()); + this.inRecovery = needsRecovery; + this.bufferManager = + needsRecovery + // review nit: false for consistency? + ? 
new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, true) + : null; + this.networkBuffersPerChannel = networkBuffersPerChannel; + this.needsRecovery = needsRecovery; + } + + @Override + void setup() throws IOException { + if (needsRecovery && networkBuffersPerChannel > 0) { + bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); + } + } + + // ------------------------------------------------------------------------ + // RecoverableInputChannel implementation + // ------------------------------------------------------------------------ + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + boolean wasEmpty; + synchronized (recoveredBuffers) { + if (isReleased) { + buffer.recycleBuffer(); + return; } - checkState( - toBeConsumedBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - toBeConsumedBuffers.size()); + // Migrate recovered buffers from RecoveredInputChannel. These buffers have been + // filtered but not yet consumed by the Task. + wasEmpty = offerRecoveredBuffer(buffer); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + @Override + public void finishRecoveredBufferDelivery() throws IOException { + upstreamReady.join(); + boolean wasEmpty; + synchronized (recoveredBuffers) { + checkState(inRecovery, "Recovery delivery already finished."); + // Append the sentinel after the last recovered buffer. The consume path flips out of + // recovery only once it polls this sentinel, guaranteeing all recovered buffers are + // consumed first. 
+ wasEmpty = + offerRecoveredBuffer( + EventSerializer.toBuffer( + EndOfFetchedChannelStateEvent.INSTANCE, false)); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + @Override + public Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException { + checkState( + bufferManager != null, + "requestRecoveryBufferBlocking called on a Local channel constructed with" + + " needsRecovery=false"); + upstreamReady.join(); + return bufferManager.requestBufferBlocking(); + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + boolean wasEmpty = false; + synchronized (recoveredBuffers) { + if (!isReleased && inRecovery) { + wasEmpty = + offerRecoveredBuffer( + EventSerializer.toBuffer( + new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Flips out of recovery the moment the consume path polls the {@code + * EndOfFetchedChannelStateEvent} sentinel, i.e. once all recovered buffers have been consumed. + * Live upstream data may flow again afterwards. + */ + @Override + public void onRecoveredStateConsumed() { + synchronized (recoveredBuffers) { + checkState(inRecovery, "Recovery already finished."); + inRecovery = false; + } + notifyChannelNonEmpty(); + } + + /** + * Appends a recovered buffer (or {@code RecoveryCheckpointBarrier} / {@code + * EndOfFetchedChannelStateEvent} sentinel) to {@link #recoveredBuffers}. + * + * @return {@code true} iff {@link #recoveredBuffers} transitioned from empty to non-empty. 
+ */ + private boolean offerRecoveredBuffer(Buffer buffer) { + assert Thread.holdsLock(recoveredBuffers); + checkState(inRecovery, "Push into recovered buffers after recovery finished."); + boolean wasEmpty = recoveredBuffers.isEmpty(); + recoveredBuffers.add(buffer); + return wasEmpty; + } + + private int nextRecoverySequenceNumber() { + assert Thread.holdsLock(recoveredBuffers); + return recoverySequenceNumber++; + } + + /** + * Walks {@link #recoveredBuffers} up to the {@link RecoveryCheckpointBarrier} sentinel matching + * {@code checkpointId}, retaining each pre-barrier recovered data buffer and removing the + * sentinel. + * + * @throws CheckpointException if no sentinel matching {@code checkpointId} is queued yet. + * @throws IOException if inspecting a queued event for the sentinel fails. + */ + private List collectPreRecoveryBarrier(long checkpointId) + throws IOException, CheckpointException { + assert Thread.holdsLock(recoveredBuffers); + List retained = new ArrayList<>(); + try { + Iterator it = recoveredBuffers.iterator(); + while (it.hasNext()) { + Buffer b = it.next(); + if (isRecoveryCheckpointBarrier(b, checkpointId)) { + it.remove(); + b.recycleBuffer(); + return retained; + } + if (b.isBuffer()) { + retained.add(b.retainBuffer()); + } + } + } catch (IOException e) { + releaseRetainedBuffers(retained); + throw e; + } + releaseRetainedBuffers(retained); + // The during-recovery sentinel for this checkpoint was never inserted into this channel + // (the recovery checkpoint trigger had already transitioned away from the drainer while + // this channel was still in recovery). The channel is simply not ready to snapshot + // recovered state for this checkpoint yet, so decline as TASK_NOT_READY: that is not + // counted against the tolerable-failure threshold, so the checkpoint is deferred and + // retried instead of failing the job. The recovered buffers remain queued and are + // captured by a later checkpoint, so no in-flight data is lost. 
+ throw new CheckpointException( + "RecoveryCheckpointBarrier for checkpoint " + + checkpointId + + " not yet present in channel " + + getChannelInfo(), + CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY); + } + + private static void releaseRetainedBuffers(List retained) { + for (Buffer buffer : retained) { + buffer.recycleBuffer(); } } + private static boolean isRecoveryCheckpointBarrier(Buffer b, long checkpointId) + throws IOException { + if (b.isBuffer()) { + return false; + } + AbstractEvent event = + EventSerializer.fromBuffer(b, RecoveryCheckpointBarrier.class.getClassLoader()); + b.setReaderIndex(0); + return event instanceof RecoveryCheckpointBarrier + && ((RecoveryCheckpointBarrier) event).getCheckpointId() == checkpointId; + } + // ------------------------------------------------------------------------ // Consume // ------------------------------------------------------------------------ + @Override public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - // Collect inflight buffers from toBeConsumedBuffers to be persisted. - // These are buffers that have not been consumed yet when the checkpoint barrier arrives. - List inflightBuffers = new ArrayList<>(); - for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { - if (bufferAndBacklog.buffer().isBuffer()) { - inflightBuffers.add(bufferAndBacklog.buffer().retainBuffer()); + try { + List toPersist; + boolean stopPersisting = false; + synchronized (recoveredBuffers) { + if (inRecovery) { + // Collect inflight buffers from recoveredBuffers to be persisted. These are + // recovered buffers that have not been consumed yet when the checkpoint barrier + // arrives. 
+ toPersist = collectPreRecoveryBarrier(barrier.getId()); + stopPersisting = true; + } else { + toPersist = Collections.emptyList(); + } + } + channelStatePersister.startPersisting(barrier.getId(), toPersist); + if (stopPersisting) { + // Recovered inflight buffers are collected in one shot and no upstream data flows + // during recovery, so close the persist window immediately to keep the persister + // from carrying a pending state into later checkpoints. + channelStatePersister.stopPersisting(barrier.getId()); } + } catch (IOException e) { + throw new CheckpointException( + "Failed to extract recovered buffers for checkpoint " + barrier.getId(), + CheckpointFailureReason.CHECKPOINT_DECLINED, + e); } - channelStatePersister.startPersisting(barrier.getId(), inflightBuffers); } public void checkpointStopped(long checkpointId) { @@ -163,6 +379,8 @@ public void checkpointStopped(long checkpointId) { @Override protected void requestSubpartitions() throws IOException { + checkState(toBeConsumedBuffers.isEmpty()); + boolean retriggerRequest = false; boolean notifyDataAvailable = false; @@ -196,6 +414,7 @@ protected void requestSubpartitions() throws IOException { this.subpartitionView = null; } else { notifyDataAvailable = true; + upstreamReady.complete(null); } } catch (PartitionNotFoundException notFound) { if (increaseBackoff()) { @@ -272,8 +491,35 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException { public Optional getNextBuffer() throws IOException { checkError(); + // Read inRecovery and poll the recovered buffer under a single lock acquisition to avoid + // grabbing the monitor twice on the hot path. 
+ boolean inRecovery; + Buffer recoveredBuf = null; + synchronized (recoveredBuffers) { + inRecovery = this.inRecovery; + if (inRecovery && !hasPendingPriorityEvent && !recoveredBuffers.isEmpty()) { + recoveredBuf = recoveredBuffers.poll(); + } + } + + if (inRecovery) { + // Always return an already-polled recovered buffer first: hasPendingPriorityEvent may + // be flipped to true by a concurrent notifyPriorityEvent() after the poll, and + // re-reading + // it here would otherwise drop this buffer. A pending priority event is served on the + // next getNextBuffer() call instead. + if (recoveredBuf != null) { + return wrapRecoveredBufferAsAvailability(recoveredBuf); + } + if (hasPendingPriorityEvent) { + return pullPriorityFromSubpartitionView(); + } + // Drain not finished yet; block normal upstream data until delivery completes. + return Optional.empty(); + } + if (!toBeConsumedBuffers.isEmpty()) { - return getNextRecoveredBuffer(); + return getBufferAndAvailability(toBeConsumedBuffers.removeFirst()); } ResultSubpartitionView subpartitionView = this.subpartitionView; @@ -335,66 +581,93 @@ public Optional getNextBuffer() throws IOException { return getBufferAndAvailability(next); } - /** - * Consumes the next buffer from toBeConsumedBuffers (recovered buffers), handling pending - * priority events and dynamic availability detection for the last recovered buffer. - */ - private Optional getNextRecoveredBuffer() throws IOException { - // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it - // from subpartitionView first, skipping toBeConsumedBuffers. This ensures priority - // events are processed immediately even when there are pending recovered buffers. 
- if (hasPendingPriorityEvent) { - checkState(subpartitionView != null, "No subpartition view available"); - BufferAndBacklog next = subpartitionView.getNextBuffer(); - checkState( - next != null && next.buffer().getDataType().hasPriority(), - "Expected priority event, but got %s", - next == null ? "null" : next.buffer().getDataType()); - - // Check for barrier to update channel state persister. - // Note: maybePersist is not needed for barriers as they are not regular data buffers. - channelStatePersister.checkForBarrier(next.buffer()); - - Buffer.DataType expectedNextDataType = next.getNextDataType(); - if (!expectedNextDataType.hasPriority()) { - // Reset hasPendingPriorityEvent to false if no more priority event - hasPendingPriorityEvent = false; - if (!toBeConsumedBuffers.isEmpty()) { - // Correct nextDataType: if toBeConsumedBuffers is not empty, the actual next - // element to consume is from toBeConsumedBuffers, not from subpartitionView - expectedNextDataType = toBeConsumedBuffers.peek().buffer().getDataType(); - } - } + private Optional pullPriorityFromSubpartitionView() throws IOException { + // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it from + // subpartitionView first, skipping recoveredBuffers. This ensures priority events are + // processed immediately even when there are pending recovered buffers. + checkState(subpartitionView != null, "No subpartition view available"); + BufferAndBacklog next = subpartitionView.getNextBuffer(); + checkState( + next != null && next.buffer().getDataType().hasPriority(), + "Expected priority event, but got %s", + next == null ? "null" : next.buffer().getDataType()); + + // Check for barrier to update channel state persister. Note: maybePersist is not needed for + // barriers as they are not regular data buffers. 
+ channelStatePersister.checkForBarrier(next.buffer()); + + Buffer.DataType expectedNextDataType = next.getNextDataType(); + if (!expectedNextDataType.hasPriority()) { + // Reset hasPendingPriorityEvent to false if no more priority event. + hasPendingPriorityEvent = false; + // Correct nextDataType: if recoveredBuffers is not empty, the actual next element to + // consume is from recoveredBuffers, not from subpartitionView. + expectedNextDataType = peekNextDataType(next.getNextDataType()); + } - return getBufferAndAvailability( - new BufferAndBacklog( - next.buffer(), - next.buffersInBacklog(), - expectedNextDataType, - next.getSequenceNumber())); - } - - BufferAndBacklog next = toBeConsumedBuffers.removeFirst(); - - // If this is the last recovered buffer and nextDataType is NONE, - // dynamically check if subpartitionView has data available. - // The last buffer's nextDataType was preset to NONE during construction, - // but subpartitionView may already have data available. - if (toBeConsumedBuffers.isEmpty() - && next.getNextDataType() == Buffer.DataType.NONE - && subpartitionView != null) { - ResultSubpartitionView.AvailabilityWithBacklog availability = - subpartitionView.getAvailabilityAndBacklog(true); - if (availability.isAvailable()) { - next = - new BufferAndBacklog( - next.buffer(), - availability.getBacklog(), - Buffer.DataType.DATA_BUFFER, - next.getSequenceNumber()); + return Optional.of( + new BufferAndAvailability( + next.buffer(), + expectedNextDataType, + next.buffersInBacklog(), + next.getSequenceNumber())); + } + + private Optional wrapRecoveredBufferAsAvailability(Buffer buf) + throws IOException { + if (buf instanceof FileRegionBuffer) { + buf = ((FileRegionBuffer) buf).readInto(inputGate.getUnpooledSegment()); + } + if (buf instanceof CompositeBuffer) { + buf = ((CompositeBuffer) buf).getFullBufferData(inputGate.getUnpooledSegment()); + } + + numBytesIn.inc(buf.readableBytes()); + numBuffersIn.inc(); + + ResultSubpartitionView view = 
subpartitionView; + Buffer.DataType upstreamProbe; + if (view != null && view.getAvailabilityAndBacklog(true).isAvailable()) { + upstreamProbe = Buffer.DataType.DATA_BUFFER; + } else { + upstreamProbe = Buffer.DataType.NONE; + } + + int sequenceNumber; + synchronized (recoveredBuffers) { + Buffer.DataType nextDataType = peekNextDataType(upstreamProbe); + sequenceNumber = nextRecoverySequenceNumber(); + NetworkActionsLogger.traceInput( + "LocalInputChannel#getNextBuffer", + buf, + inputGate.getOwningTaskName(), + channelInfo, + channelStatePersister, + sequenceNumber); + // buffersInBacklog is set to 0 as these are recovered buffers. + return Optional.of(new BufferAndAvailability(buf, nextDataType, 0, sequenceNumber)); + } + } + + private Buffer.DataType peekNextDataType(Buffer.DataType nextDataTypeOnUpstream) { + synchronized (recoveredBuffers) { + if (!recoveredBuffers.isEmpty()) { + return recoveredBuffers.peek().getDataType(); + } + if (inRecovery) { + // If this is the last currently available recovered buffer, hide upstream data + // until the EndOfFetchedChannelStateEvent sentinel flips the channel out of + // recovery. The last buffer's nextDataType is effectively NONE while the drain can + // still append more recovered buffers. + return Buffer.DataType.NONE; } } - return getBufferAndAvailability(next); + // If this is the last recovered buffer after delivery finished, dynamically check if + // subpartitionView has data available. The last buffer's nextDataType may have been NONE + // while recovered data was still being delivered, but subpartitionView may already have + // data + // available now. 
+ return nextDataTypeOnUpstream; } private Optional getBufferAndAvailability(BufferAndBacklog next) @@ -435,7 +708,7 @@ public void notifyDataAvailable(ResultSubpartitionView view) { @Override public void notifyPriorityEvent(int prioritySequenceNumber) { // Set flag so that getNextBuffer() knows to fetch priority event from subpartitionView - // before consuming toBeConsumedBuffers. + // before consuming recoveredBuffers. hasPendingPriorityEvent = true; super.notifyPriorityEvent(prioritySequenceNumber); } @@ -506,18 +779,30 @@ void releaseAllResources() throws IOException { if (!isReleased) { isReleased = true; + upstreamReady.completeExceptionally(new CancelTaskException("Channel released.")); + ResultSubpartitionView view = subpartitionView; if (view != null) { view.releaseAllResources(); subpartitionView = null; } - // Release any remaining buffers in toBeConsumedBuffers to avoid memory leak. - // These may be recovered buffers or partial buffers from FullyFilledBuffer. + // Release any remaining buffers in recoveredBuffers (migrated recovered buffers not yet + // consumed) and toBeConsumedBuffers (FullyFilledBuffer partial splits) to avoid memory + // leak. + synchronized (recoveredBuffers) { + for (Buffer buffer : recoveredBuffers) { + buffer.recycleBuffer(); + } + recoveredBuffers.clear(); + } for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { bufferAndBacklog.buffer().recycleBuffer(); } toBeConsumedBuffers.clear(); + if (bufferManager != null) { + bufferManager.releaseAllBuffers(new ArrayDeque<>()); + } } } @@ -534,14 +819,16 @@ void announceBufferSize(int newBufferSize) { @Override int getBuffersInUseCount() { ResultSubpartitionView view = this.subpartitionView; - return toBeConsumedBuffers.size() + (view == null ? 0 : view.getNumberOfQueuedBuffers()); + return recoveredBuffers.size() + + toBeConsumedBuffers.size() + + (view == null ? 
0 : view.getNumberOfQueuedBuffers()); } @Override public int unsynchronizedGetNumberOfQueuedBuffers() { ResultSubpartitionView view = subpartitionView; - int count = toBeConsumedBuffers.size(); + int count = recoveredBuffers.size() + toBeConsumedBuffers.size(); if (view != null) { count += view.unsynchronizedGetNumberOfQueuedBuffers(); } @@ -569,4 +856,9 @@ public String toString() { ResultSubpartitionView getSubpartitionView() { return subpartitionView; } + + @VisibleForTesting + void completeUpstreamReadyForTest() { + upstreamReady.complete(null); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java index bdde2244f38ef6..de058bd3723ce5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java @@ -19,14 +19,11 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.runtime.io.network.TaskEventPublisher; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import java.util.ArrayDeque; - import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -64,7 +61,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { return new LocalInputChannel( inputGate, 
getChannelIndex(), @@ -77,6 +74,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); + networkBuffersPerChannel, + needsRecovery); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java new file mode 100644 index 00000000000000..9ccca71237f1eb --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import java.io.IOException; + +/** Physical input channel that can receive recovered buffers pushed by the spill drain. */ +// review: please rephrase: physical, spill, drain +// review: should this be in the previous commit? 
according to THIS commit message +// review todo impl +@Internal +public interface RecoverableInputChannel { + + InputChannelInfo getChannelInfo(); + + /** + * Appends a recovered buffer or recovery-checkpoint sentinel. Released channels recycle the + * buffer silently. + */ + void onRecoveredStateBuffer(Buffer buffer); + + /** + * Marks producer-side recovery delivery complete. Implementations wait for upstream readiness + * before flipping this state so channels without spill entries still observe the same handoff. + */ + void finishRecoveredBufferDelivery() throws IOException, InterruptedException; + + /** + * Inserts a {@code RecoveryCheckpointBarrier} for {@code checkpointId} into this channel's + * recovery queue if the channel is still in recovery. + */ + void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException; + + /** + * Blocks until a buffer is available from this channel's own buffer pool. Implementations must + * first await upstream readiness and must be invoked outside the drainer lock. + */ + Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException; + + /** + * Invoked by the consume path the moment it polls the {@code EndOfFetchedChannelStateEvent} + * sentinel, i.e. once all recovered buffers have been consumed. Implementations flip out of + * recovery, release any upstream events held back during recovery, and reopen the upstream so + * live data may flow again. + */ + // review: is this guaranteed to be called after finishRecoveredBufferDelivery? 
+ void onRecoveredStateConsumed() throws IOException; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index d9b7885815bd12..d291e4ea8e9962 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -19,8 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; @@ -29,13 +27,10 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; -import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; import org.apache.flink.runtime.io.network.partition.ChannelStateHolder; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -62,13 +57,6 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private final CompletableFuture stateConsumedFuture = new CompletableFuture<>(); protected final BufferManager 
bufferManager; - /** - * Future that completes when recovered buffers have been filtered for this channel. This - * completes before stateConsumedFuture, enabling earlier RUNNING state transition when - * unaligned checkpoint during recovery is enabled. - */ - private final CompletableFuture bufferFilteringCompleteFuture = new CompletableFuture<>(); - @GuardedBy("receivedBuffers") private boolean isReleased; @@ -105,7 +93,7 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan numBytesIn, numBuffersIn); - bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); + bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, true); this.networkBuffersPerChannel = networkBuffersPerChannel; } @@ -115,23 +103,14 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { this.channelStateWriter = checkNotNull(channelStateWriter); } - public final InputChannel toInputChannel() throws IOException { - Preconditions.checkState( - bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - Preconditions.checkState( - stateConsumedFuture.isDone(), "recovered state is not fully consumed"); - } - - // Extract remaining buffers before conversion. - // These buffers have been filtered but not yet consumed by the Task. - final ArrayDeque remainingBuffers; + public final InputChannel toInputChannel(boolean needsRecovery) throws IOException { + // With checkpointing-during-recovery, data is spilled instead of queued here. 
synchronized (receivedBuffers) { - remainingBuffers = new ArrayDeque<>(receivedBuffers); - receivedBuffers.clear(); + Preconditions.checkState(receivedBuffers.isEmpty(), "Received buffer should be empty."); } - final InputChannel inputChannel = toInputChannelInternal(remainingBuffers); + final InputChannel inputChannel = toInputChannelInternal(needsRecovery); + inputChannel.setup(); inputChannel.checkpointStopped(lastStoppedCheckpointId); return inputChannel; } @@ -142,23 +121,12 @@ public void checkpointStopped(long checkpointId) { } /** - * Creates the physical InputChannel from this recovered channel. - * - * @param remainingBuffers buffers that have been filtered but not yet consumed by the Task. - * These buffers will be migrated to the new physical channel. - * @return the physical InputChannel (LocalInputChannel or RemoteInputChannel) + * Creates the physical {@link InputChannel}; {@code needsRecovery} controls whether it starts + * in recovery. */ - protected abstract InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) + protected abstract InputChannel toInputChannelInternal(boolean needsRecovery) throws IOException; - /** - * Returns the future that completes when buffer filtering is complete. This future completes - * before stateConsumedFuture, at the point when finishReadRecoveredState() is called. 
- */ - CompletableFuture getBufferFilteringCompleteFuture() { - return bufferFilteringCompleteFuture; - } - CompletableFuture getStateConsumedFuture() { return stateConsumedFuture; } @@ -166,10 +134,7 @@ CompletableFuture getStateConsumedFuture() { public void onRecoveredStateBuffer(Buffer buffer) { boolean recycleBuffer = true; NetworkActionsLogger.traceRecover( - "InputChannelRecoveredStateHandler#recover", - buffer, - inputGate.getOwningTaskName(), - channelInfo); + "NoSpillingHandler#recover", buffer, inputGate.getOwningTaskName(), channelInfo); try { final boolean wasEmpty; synchronized (receivedBuffers) { @@ -195,21 +160,9 @@ public void onRecoveredStateBuffer(Buffer buffer) { } public void finishReadRecoveredState() throws IOException { - // Adding the event and completing the future must be atomic under receivedBuffers lock. - // Without this, either ordering has a race: - // - event first: task thread consumes EndOfInputChannelStateEvent, which completes - // stateConsumedFuture. When checkpointing during recovery is disabled, - // stateConsumedFuture triggers requestPartitions -> toInputChannel(), which - // fails because bufferFilteringCompleteFuture is not yet done. - // - future first: toInputChannel() extracts buffers before the event is added, - // losing the EndOfInputChannelStateEvent. - // Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on - // receivedBuffers, so holding the same lock here guarantees - // bufferFilteringCompleteFuture is always done before stateConsumedFuture. 
synchronized (receivedBuffers) { onRecoveredStateBuffer( EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); - bufferFilteringCompleteFuture.complete(null); } bufferManager.releaseFloatingBuffers(); LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); @@ -229,8 +182,6 @@ private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException { if (next == null) { return null; } else if (isEndOfInputChannelStateEvent(next)) { - Preconditions.checkState( - bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); stateConsumedFuture.complete(null); return null; } else { @@ -307,7 +258,7 @@ boolean isReleased() { } } - void releaseAllResources() throws IOException { + public void releaseAllResources() throws IOException { ArrayDeque releasedBuffers = new ArrayDeque<>(); boolean shouldRelease = false; @@ -338,26 +289,7 @@ public Buffer requestBufferBlocking() throws InterruptedException, IOException { bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); exclusiveBuffersAssigned = true; } - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - // When checkpoint-during-recovery is not enabled, the original blocking allocation - // is used as-is — no heap buffer fallback, no behavior change from the legacy path. - return bufferManager.requestBufferBlocking(); - } - // Use heap buffer fallback to avoid deadlock during filtering recovery: the filtering - // thread first requests buffers to read state (pre-filter), then requests more buffers - // to write filtered output (post-filter). If pre-filter buffers exhaust the pool, - // post-filter allocation blocks, stalling the thread so pre-filter buffers can never - // be consumed and released — the thread deadlocks itself. Heap buffers bypass the pool - // so post-filter writes always proceed. Both call sites (getBuffer and filterAndRewrite) - // go through this method, so the fallback applies uniformly. 
- // TODO: replace heap fallback with disk spilling to bound memory usage in FLINK-38544. - Buffer buffer = bufferManager.requestBuffer(); - if (buffer != null) { - return buffer; - } - MemorySegment memorySegment = - MemorySegmentFactory.allocateUnpooledSegment(MemoryManager.DEFAULT_PAGE_SIZE); - return new NetworkBuffer(memorySegment, FreeingBufferRecycler.INSTANCE); + return bufferManager.requestBufferBlocking(); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 645446120d09e8..e272aef9efa270 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.execution.CancelTaskException; @@ -62,9 +63,11 @@ import java.util.List; import java.util.Optional; import java.util.OptionalLong; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.RECOVERY_METADATA; import static org.apache.flink.util.Preconditions.checkArgument; @@ -72,7 +75,7 @@ import static org.apache.flink.util.Preconditions.checkState; /** An input channel, which requests a remote 
partition queue. */ -public class RemoteInputChannel extends InputChannel { +public class RemoteInputChannel extends InputChannel implements RecoverableInputChannel { private static final Logger LOG = LoggerFactory.getLogger(RemoteInputChannel.class); private static final int NONE = -1; @@ -107,6 +110,8 @@ public class RemoteInputChannel extends InputChannel { /** The initial number of exclusive buffers assigned to this channel. */ private final int initialCredit; + private final boolean needsRecovery; + /** The milliseconds timeout for partition request listener in result partition manager. */ private final int partitionRequestListenerTimeout; @@ -123,6 +128,43 @@ public class RemoteInputChannel extends InputChannel { private final ChannelStatePersister channelStatePersister; + /** + * Whether the channel is still replaying recovered state. Recovered buffers delivered by the + * spill drain are appended directly to {@link #receivedBuffers}, so the consume path needs no + * recovery-specific branch. Starts {@code false} for channels that do not need recovery and is + * flipped to {@code false} the moment the consume path polls the {@code + * EndOfFetchedChannelStateEvent} sentinel that the drain appended after the last recovered + * buffer (see {@link #onRecoveredStateConsumed()}). + */ + @GuardedBy("receivedBuffers") + private boolean inRecovery; + + /** + * Sequence number assigned to recovered buffers, starting at {@link Integer#MIN_VALUE}, + * consistent with {@link RecoveredInputChannel}. + */ + @GuardedBy("receivedBuffers") + private int recoverySequenceNumber = Integer.MIN_VALUE; + + /** + * Ordinary (non-priority) upstream events received while recovery is still in progress. They + * cannot enter {@link #receivedBuffers} ahead of the recovered buffers, so they are stashed + * here and appended once recovery delivery finishes. 
Credit is suppressed during recovery, so + * the upstream can only send events (never data buffers) before {@link + * #finishRecoveredBufferDelivery()}. + */ + @GuardedBy("receivedBuffers") + private final ArrayDeque recoveryEventStash = new ArrayDeque<>(); + + /** + * One-shot latch that opens once the upstream reader is registered and the connection is live + * (signalled by the first {@link #onBuffer} or by {@link #releaseAllResources()}). + * Recovery-side awaiters block on it before handing off; once open, {@link + * CountDownLatch#countDown()} on the hot path is a cheap idempotent no-op, unlike completing a + * {@code CompletableFuture}. + */ + private final CountDownLatch upstreamReady = new CountDownLatch(1); + private long totalQueueSizeInBytes; public RemoteInputChannel( @@ -139,7 +181,7 @@ public RemoteInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + boolean needsRecovery) { super( inputGate, @@ -156,31 +198,12 @@ public RemoteInputChannel( this.initialCredit = networkBuffersPerChannel; this.connectionId = checkNotNull(connectionId); this.connectionManager = checkNotNull(connectionManager); - this.bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - for (Buffer buffer : initialRecoveredBuffers) { - // subpartitionId is set to 0 for recovered buffers. This is correct because: - // 1) For single-subpartition channels, the only valid subpartition is 0. 
- // 2) For multi-subpartition channels (consumedSubpartitionIndexSet.size() > 1), - // RecoveryMetadata events embedded in the recovered buffer sequence track - // the actual subpartition context for proper routing. - SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, seqNum++, 0); - receivedBuffers.add(sequenceBuffer); - totalQueueSizeInBytes += buffer.getSize(); - } - checkState( - receivedBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - receivedBuffers.size()); - } + this.needsRecovery = needsRecovery; + this.bufferManager = + new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, !needsRecovery); + this.channelStatePersister = + new ChannelStatePersister(checkNotNull(stateWriter), getChannelInfo()); + this.inRecovery = needsRecovery; } @VisibleForTesting @@ -188,6 +211,11 @@ void setExpectedSequenceNumber(int expectedSequenceNumber) { this.expectedSequenceNumber = expectedSequenceNumber; } + @VisibleForTesting + void completeUpstreamReadyForTest() { + upstreamReady.countDown(); + } + /** * Setup includes assigning exclusive buffers to this input channel, and this method should be * called only once after this input channel is created. @@ -201,6 +229,113 @@ void setup() throws IOException { bufferManager.requestExclusiveBuffers(initialCredit); } + // ------------------------------------------------------------------------ + // RecoverableInputChannel implementation + // ------------------------------------------------------------------------ + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + boolean wasEmpty; + synchronized (receivedBuffers) { + if (isReleased.get()) { + buffer.recycleBuffer(); + return; + } + // Migrate recovered buffers from RecoveredInputChannel. These buffers have been + // filtered but not yet consumed by the Task. They are appended to receivedBuffers so + // the consume path stays identical to the non-recovery case. 
+ wasEmpty = appendRecoveredBuffer(buffer); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + @Override + public void finishRecoveredBufferDelivery() throws IOException, InterruptedException { + upstreamReady.await(); + boolean wasEmpty; + synchronized (receivedBuffers) { + // A release may have opened the latch instead of the first buffer; bail out so we never + // append to a queue that releaseAllResources() already cleared. + if (isReleased.get()) { + return; + } + checkState(inRecovery, "Recovery delivery already finished."); + // Append the sentinel after the last recovered buffer. The consume path flips out of + // recovery (unstash + reopen credit) only once it polls this sentinel, guaranteeing all + // recovered buffers are consumed first. + wasEmpty = + appendRecoveredBuffer( + EventSerializer.toBuffer( + EndOfFetchedChannelStateEvent.INSTANCE, false)); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Flips out of recovery once the consume path polls the {@code EndOfFetchedChannelStateEvent} + * sentinel: releases the upstream events stashed during recovery so they are consumed after the + * recovered buffers, then reopens the suppressed credit notifications. + */ + @Override + public void onRecoveredStateConsumed() throws IOException { + synchronized (receivedBuffers) { + checkState(inRecovery, "Recovery already finished."); + inRecovery = false; + recoveryEventStash.forEach(receivedBuffers::add); + recoveryEventStash.clear(); + } + notifyChannelNonEmpty(); + // Credit notifications are suppressed while recovery borrows the exclusive buffers. + bufferManager.enableNotify(); + } + + @Override + public Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException { + upstreamReady.await(); + // If a release opened the latch instead of the first buffer, requestBufferBlocking() + // detects + // the released channel and throws CancelTaskException. 
+ return bufferManager.requestBufferBlocking(); + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + boolean wasEmpty = false; + synchronized (receivedBuffers) { + if (!isReleased.get() && inRecovery) { + wasEmpty = + appendRecoveredBuffer( + EventSerializer.toBuffer( + new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Appends a recovered buffer (or {@code RecoveryCheckpointBarrier} sentinel) to {@link + * #receivedBuffers} with a recovery sequence number. + * + * @return {@code true} iff {@code receivedBuffers} transitioned from empty to non-empty. + */ + @GuardedBy("receivedBuffers") + private boolean appendRecoveredBuffer(Buffer buffer) { + boolean wasEmpty = receivedBuffers.isEmpty(); + // Recovered buffers carry no per-buffer subpartition id (NONE): they are snapshotted via + // the + // recovery path, never via getInflightBuffersUnsafe which is the only consumer of that + // field. + receivedBuffers.add(new SequenceBuffer(buffer, recoverySequenceNumber++, NONE)); + totalQueueSizeInBytes += buffer.getSize(); + return wasEmpty; + } + // ------------------------------------------------------------------------ // Consume // ------------------------------------------------------------------------ @@ -343,14 +478,18 @@ public boolean isReleased() { @Override void releaseAllResources() throws IOException { if (isReleased.compareAndSet(false, true)) { + // Unblock any thread awaiting upstreamReady (drain still in flight) so it falls + // through and observes the released state instead of deadlocking. 
+ upstreamReady.countDown(); final ArrayDeque releasedBuffers; synchronized (receivedBuffers) { releasedBuffers = - receivedBuffers.stream() + Stream.concat(receivedBuffers.stream(), recoveryEventStash.stream()) .map(sb -> sb.buffer) .collect(Collectors.toCollection(ArrayDeque::new)); receivedBuffers.clear(); + recoveryEventStash.clear(); } bufferManager.releaseAllBuffers(releasedBuffers); @@ -558,6 +697,10 @@ public int getInitialCredit() { return initialCredit; } + public boolean needsRecovery() { + return needsRecovery; + } + public BufferProvider getBufferProvider() throws IOException { if (isReleased.get()) { return null; @@ -596,6 +739,11 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart throws IOException { boolean recycleBuffer = true; + // The first buffer from the producer proves the upstream reader is registered and the + // connection is live; release any recovery-side awaiter. On later buffers this is a cheap + // idempotent no-op (the latch count is already zero). + upstreamReady.countDown(); + try { if (expectedSequenceNumber != sequenceNumber) { onError(new BufferReorderingException(expectedSequenceNumber, sequenceNumber)); @@ -633,7 +781,20 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart firstPriorityEvent = addPriorityBuffer(sequenceBuffer); recycleBuffer = false; } else { - receivedBuffers.add(sequenceBuffer); + if (inRecovery) { // review nit: merge if + // The upstream has no credit until recovery delivery finishes, so it can + // only + // send events here, never data buffers. Stash ordinary events so they are + // consumed after the recovered buffers; data buffers are a protocol + // violation. 
+ checkState( + !buffer.isBuffer(), // review todo: check what events can be sent + "Received live data buffer during recovery on channel %s", + getChannelInfo()); + recoveryEventStash.add(sequenceBuffer); + } else { + receivedBuffers.add(sequenceBuffer); + } recycleBuffer = false; if (dataType.requiresAnnouncement()) { firstPriorityEvent = addPriorityBuffer(announce(sequenceBuffer)); @@ -713,30 +874,127 @@ private void checkAnnouncedOnlyOnce(SequenceBuffer sequenceBuffer) { } /** - * Spills all queued buffers on checkpoint start. If barrier has already been received (and - * reordered), spill only the overtaken buffers. + * Persists inflight data on checkpoint start. During recovery, persists recovered buffers + * before the matching RecoveryCheckpointBarrier sentinel; after recovery, uses the normal + * remote-channel barrier sequence tracking and persists overtaken live buffers. */ + // review todo public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - synchronized (receivedBuffers) { - if (barrier.getId() < lastBarrierId) { - throw new CheckpointException( - String.format( - "Sequence number for checkpoint %d is not known (it was likely been overwritten by a newer checkpoint %d)", - barrier.getId(), lastBarrierId), - CheckpointFailureReason - .CHECKPOINT_SUBSUMED); // currently, at most one active unaligned - // checkpoint is possible - } else if (barrier.getId() > lastBarrierId) { - // This channel has received some obsolete barrier, older compared to the - // checkpointId - // which we are processing right now, and we should ignore that obsoleted checkpoint - // barrier sequence number. - resetLastBarrier(); + try { + List toPersist; + synchronized (receivedBuffers) { + if (inRecovery) { + toPersist = collectPreRecoveryBarrier(barrier.getId()); + } else { + if (barrier.getId() < lastBarrierId) { + // Currently, at most one active unaligned checkpoint is possible. 
+ throw new CheckpointException( + String.format( + "Sequence number for checkpoint %d is not known (it was likely been overwritten by a newer checkpoint %d)", + barrier.getId(), lastBarrierId), + CheckpointFailureReason.CHECKPOINT_SUBSUMED); + } else if (barrier.getId() > lastBarrierId) { + // This channel has received some obsolete barrier, older compared to the + // checkpointId which we are processing right now, and we should ignore that + // obsoleted checkpoint barrier sequence number. + resetLastBarrier(); + } + toPersist = getInflightBuffersUnsafe(barrier.getId()); + } + channelStatePersister.startPersisting(barrier.getId(), toPersist); + if (inRecovery) { + // Recovered inflight buffers are collected in one shot and the upstream sends + // no + // data during recovery, so close the persist window immediately to keep the + // persister from carrying a pending state into later checkpoints. + channelStatePersister.stopPersisting(barrier.getId()); + } } + } catch (IOException e) { + throw new CheckpointException( + "Failed to extract recovered buffers for checkpoint " + barrier.getId(), + CheckpointFailureReason.CHECKPOINT_DECLINED, + e); + } + } - channelStatePersister.startPersisting( - barrier.getId(), getInflightBuffersUnsafe(barrier.getId())); + /** + * Walks {@link #receivedBuffers} (skipping priority events) up to the {@link + * RecoveryCheckpointBarrier} sentinel matching {@code checkpointId}, retaining each pre-barrier + * recovered data buffer and removing the sentinel. During recovery the upstream has no credit, + * so {@code receivedBuffers} holds only recovered buffers, sentinels, and priority events — no + * live data buffers. + * + * @throws IOException if a queued event cannot be deserialized during the walk. + * @throws CheckpointException if no sentinel matching {@code checkpointId} is queued yet.
+ */ + @GuardedBy("receivedBuffers") + private List collectPreRecoveryBarrier(long checkpointId) + throws IOException, CheckpointException { + assert Thread.holdsLock(receivedBuffers); + List retained = new ArrayList<>(); + SequenceBuffer sentinel = null; + try { + Iterator it = receivedBuffers.iterator(); + // Priority events are stored separately at the head and never carry recovered data. + Iterators.advance(it, receivedBuffers.getNumPriorityElements()); + while (it.hasNext()) { + SequenceBuffer sb = it.next(); + if (isRecoveryCheckpointBarrier(sb.buffer, checkpointId)) { + sentinel = sb; + break; + } + // Skip non-data events (e.g. the EndOfFetchedChannelStateEvent sentinel appended + // after the recovered buffers): only recovered data buffers are snapshotted. + if (sb.buffer.isBuffer()) { + retained.add(sb.buffer.retainBuffer()); + } + } + } catch (IOException e) { + releaseRetainedBuffers(retained); + throw e; + } + if (sentinel == null) { + releaseRetainedBuffers(retained); + // The during-recovery sentinel for this checkpoint was never inserted into this channel + // (the recovery checkpoint trigger had already transitioned away from the drainer while + // this channel was still in recovery). The channel is simply not ready to snapshot + // recovered state for this checkpoint yet, so decline as TASK_NOT_READY: that is not + // counted against the tolerable-failure threshold, so the checkpoint is deferred and + // retried instead of failing the job. The recovered buffers remain queued and are + // captured by a later checkpoint, so no in-flight data is lost. + throw new CheckpointException( + "RecoveryCheckpointBarrier for checkpoint " + + checkpointId + + " not yet present in channel " + + getChannelInfo(), + CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY); + } + // receivedBuffers is a PrioritizedDeque whose iterator() is read-only; remove the matched + // sentinel by identity through its mutable removal API. 
+ final SequenceBuffer matched = sentinel; + receivedBuffers.getAndRemove(sb -> sb == matched); + totalQueueSizeInBytes -= matched.buffer.getSize(); + matched.buffer.recycleBuffer(); + return retained; + } + + private static void releaseRetainedBuffers(List retained) { + for (Buffer buffer : retained) { + buffer.recycleBuffer(); + } + } + + private static boolean isRecoveryCheckpointBarrier(Buffer b, long checkpointId) + throws IOException { + if (b.isBuffer()) { + return false; } + AbstractEvent event = + EventSerializer.fromBuffer(b, RecoveryCheckpointBarrier.class.getClassLoader()); + b.setReaderIndex(0); + return event instanceof RecoveryCheckpointBarrier + && ((RecoveryCheckpointBarrier) event).getCheckpointId() == checkpointId; } public void checkpointStopped(long checkpointId) { @@ -909,13 +1167,13 @@ public void onError(Throwable cause) { } /** - * When receivedBuffers contains migrated buffers from RecoveredInputChannel, they can be read - * before requestSubpartitions(). In that case only check for errors. Once migrated buffers are - * drained, require full client initialization check. + * Allows reads while recovery data or already queued network data is available before the + * remote partition request is fully initialized. If neither recovery nor queued data can + * satisfy the read, require the partition request client to be initialized. 
*/ private void checkReadability() throws IOException { assert Thread.holdsLock(receivedBuffers); - if (receivedBuffers.isEmpty()) { + if (!inRecovery && receivedBuffers.isEmpty()) { checkPartitionRequestQueueInitialized(); } else { checkError(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java index 2cfff6f5e7972a..b76aa347fe6445 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java @@ -20,13 +20,11 @@ import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.io.IOException; -import java.util.ArrayDeque; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -68,8 +66,7 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) - throws IOException { + protected InputChannel toInputChannelInternal(boolean needsRecovery) throws IOException { RemoteInputChannel remoteInputChannel = new RemoteInputChannel( inputGate, @@ -85,8 +82,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); - remoteInputChannel.setup(); + needsRecovery); return remoteInputChannel; } } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 438efa2f58bd5b..979a674751698e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -243,7 +243,9 @@ public class SingleInputGate extends IndexedInputGate { */ private final int[] endOfPartitions; - private volatile boolean checkpointingDuringRecoveryEnabled = false; + // review: this should be a parameter to inputGate.finishReadRecoveredState() or a new method + // finishFetchState() + // private volatile boolean checkpointingDuringRecoveryEnabled = false; public SingleInputGate( String owningTaskName, @@ -332,32 +334,12 @@ public CompletableFuture getStateConsumedFuture() { } @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) { - this.checkpointingDuringRecoveryEnabled = enabled; - } - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return checkpointingDuringRecoveryEnabled; - } - - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - synchronized (requestLock) { - List> futures = new ArrayList<>(numberOfInputChannels); - for (InputChannel inputChannel : inputChannels()) { - if (inputChannel instanceof RecoveredInputChannel) { - futures.add( - ((RecoveredInputChannel) inputChannel) - .getBufferFilteringCompleteFuture()); - } - } - return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); - } + public void requestPartitions() { + requestPartitions(false); } @Override - public void requestPartitions() { + public void requestPartitions(boolean needsRecovery) { synchronized (requestLock) { if (!requestedPartitionsFlag) { if (closeFuture.isDone()) { @@ -376,7 +358,7 @@ public void 
requestPartitions() { numInputChannels, numberOfInputChannels)); } - convertRecoveredInputChannels(); + convertRecoveredInputChannels(needsRecovery); internalRequestPartitions(); } @@ -390,12 +372,18 @@ public void requestPartitions() { } } + @VisibleForTesting + public void convertRecoveredInputChannels() { + convertRecoveredInputChannels(false); + } + /** * Converts all {@link RecoveredInputChannel}s to their real channel types ({@link - * LocalInputChannel} or {@link RemoteInputChannel}). + * LocalInputChannel} or {@link RemoteInputChannel}). {@code needsRecovery} controls whether the + * converted physical channels start in recovery. */ @VisibleForTesting - public void convertRecoveredInputChannels() { + public void convertRecoveredInputChannels(boolean needsRecovery) { LOG.debug("Converting recovered input channels ({} channels)", getNumberOfInputChannels()); for (Map inputChannelsForCurrentPartition : inputChannels.values()) { @@ -413,7 +401,7 @@ public void convertRecoveredInputChannels() { // order with onRecoveredStateBuffer() which acquires receivedBuffers // first and then inputChannelsWithData. 
InputChannel realInputChannel = - ((RecoveredInputChannel) inputChannel).toInputChannel(); + ((RecoveredInputChannel) inputChannel).toInputChannel(needsRecovery); inputChannel.releaseAllResources(); int buffersInUseCount = realInputChannel.getBuffersInUseCount(); @@ -595,6 +583,11 @@ public InputChannel getChannel(int channelIndex) { return channels[channelIndex]; } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + return channels[channelInfo.getInputChannelIdx()]; + } + // ------------------------------------------------------------------------ // Setup/Life-cycle // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index dda71c63be38f4..b7a708f38f27e8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -176,6 +176,13 @@ public InputChannel getChannel(int channelIndex) { .getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + // The member gate's local channel index is carried directly in channelInfo, so resolve + // through the owning gate instead of the gate-global getChannel(int) addressing. 
+ return inputGatesByGateIndex.get(channelInfo.getGateIdx()).getChannel(channelInfo); + } + @Override public boolean isFinished() { return inputGatesWithRemainingData.isEmpty(); @@ -351,18 +358,14 @@ public CompletableFuture getStateConsumedFuture() { } @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.allOf( - inputGatesByGateIndex.values().stream() - .map(InputGate::getBufferFilteringCompleteFuture) - .collect(Collectors.toList()) - .toArray(new CompletableFuture[] {})); + public void requestPartitions() throws IOException { + requestPartitions(false); } @Override - public void requestPartitions() throws IOException { + public void requestPartitions(boolean needsRecovery) throws IOException { for (InputGate inputGate : inputGatesByGateIndex.values()) { - inputGate.requestPartitions(); + inputGate.requestPartitions(needsRecovery); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java index 15182cedadb9fc..2d717f66fe8a88 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java @@ -35,7 +35,6 @@ import javax.annotation.Nullable; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Optional; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; @@ -171,21 +170,23 @@ public String toString() { public RemoteInputChannel toRemoteInputChannel( ConnectionID producerAddress, ResultPartitionID resultPartitionID) { - return new RemoteInputChannel( - inputGate, - getChannelIndex(), - resultPartitionID, - consumedSubpartitionIndexSet, - checkNotNull(producerAddress), - connectionManager, - 
initialBackoff, - maxBackoff, - partitionRequestListenerTimeout, - networkBuffersPerChannel, - metrics.getNumBytesInRemoteCounter(), - metrics.getNumBuffersInRemoteCounter(), - channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + RemoteInputChannel channel = + new RemoteInputChannel( + inputGate, + getChannelIndex(), + resultPartitionID, + consumedSubpartitionIndexSet, + checkNotNull(producerAddress), + connectionManager, + initialBackoff, + maxBackoff, + partitionRequestListenerTimeout, + networkBuffersPerChannel, + metrics.getNumBytesInRemoteCounter(), + metrics.getNumBuffersInRemoteCounter(), + channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, + false); + return channel; } public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID) { @@ -201,7 +202,8 @@ public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), channelStateWriter == null ? 
ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + networkBuffersPerChannel, + false); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index bff412f53b3303..5d1ca9c680b81b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -80,6 +80,11 @@ public InputChannel getChannel(int channelIndex) { return inputGate.getChannel(channelIndex); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + return inputGate.getChannel(channelInfo); + } + @Override public int getGateIndex() { return inputGate.getGateIndex(); @@ -121,13 +126,13 @@ public CompletableFuture getStateConsumedFuture() { } @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return inputGate.getBufferFilteringCompleteFuture(); + public void requestPartitions() throws IOException { + inputGate.requestPartitions(); } @Override - public void requestPartitions() throws IOException { - inputGate.requestPartitions(); + public void requestPartitions(boolean needsRecovery) throws IOException { + inputGate.requestPartitions(needsRecovery); } @Override @@ -165,16 +170,6 @@ public void finishReadRecoveredState() throws IOException { inputGate.finishReadRecoveredState(); } - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) { - inputGate.setCheckpointingDuringRecoveryEnabled(enabled); - } - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return inputGate.isCheckpointingDuringRecoveryEnabled(); - } - private BufferOrEvent updateMetrics(BufferOrEvent bufferOrEvent) { int incomingDataSize = bufferOrEvent.getSize(); diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java index a3743edec7f940..26f0b5f8dfff08 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java @@ -326,8 +326,7 @@ public void close() throws IOException { // terminate the TM Exception err = null; for (InputChannelInfo channelInfo : new ArrayList<>(recordDeserializers.keySet())) { - final boolean hadError = - checkpointedInputGate.getChannel(channelInfo.getInputChannelIdx()).hasError(); + final boolean hadError = checkpointedInputGate.getChannel(channelInfo).hasError(); try { releaseDeserializer(channelInfo); } catch (Exception e) { diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java index 8ca37055bc3886..c918f9db0ee8ad 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java @@ -21,7 +21,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; -import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; import java.io.IOException; @@ -44,9 +43,7 @@ public BarrierHandlerState alignedCheckpointTimeout( state.prioritizeAllAnnouncements(); CheckpointBarrier unalignedBarrier = checkpointBarrier.asUnaligned(); 
controller.initInputsCheckpoint(unalignedBarrier); - for (CheckpointableInput input : state.getInputs()) { - input.checkpointStarted(unalignedBarrier); - } + state.onCheckpointStartedForAllInputs(unalignedBarrier); controller.triggerGlobalCheckpoint(unalignedBarrier); return new AlternatingCollectingBarriersUnaligned(true, state); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java index af04f4f8107d2a..1e12b4757e876b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java @@ -72,9 +72,7 @@ public BarrierHandlerState barrierReceived( CheckpointBarrier unalignedBarrier = checkpointBarrier.asUnaligned(); controller.initInputsCheckpoint(unalignedBarrier); - for (CheckpointableInput input : channelState.getInputs()) { - input.checkpointStarted(unalignedBarrier); - } + channelState.onCheckpointStartedForAllInputs(unalignedBarrier); controller.triggerGlobalCheckpoint(unalignedBarrier); if (controller.allBarriersReceived()) { for (CheckpointableInput input : channelState.getInputs()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java index 8f6bb211d2bf30..3f289895ac34de 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java @@ -18,9 +18,14 @@ package org.apache.flink.streaming.runtime.io.checkpointing; +import 
org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; +import org.apache.flink.util.ExceptionUtils; import java.io.IOException; import java.util.HashMap; @@ -28,6 +33,7 @@ import java.util.Map; import java.util.Set; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** @@ -35,6 +41,7 @@ * and {@link AbstractAlternatingAlignedBarrierHandlerState}. */ final class ChannelState { + private final Map sequenceNumberInAnnouncedChannels = new HashMap<>(); @@ -47,8 +54,21 @@ final class ChannelState { private final CheckpointableInput[] inputs; + private final RecoveryCheckpointTrigger recoveryCheckpointTrigger; + + private final ChannelStateWriter channelStateWriter; + public ChannelState(CheckpointableInput[] inputs) { + this(inputs, RecoveryCheckpointTrigger.NO_OP, ChannelStateWriter.NO_OP); + } + + public ChannelState( + CheckpointableInput[] inputs, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter) { this.inputs = inputs; + this.recoveryCheckpointTrigger = checkNotNull(recoveryCheckpointTrigger); + this.channelStateWriter = checkNotNull(channelStateWriter); } public void blockChannel(InputChannelInfo channelInfo) { @@ -98,4 +118,34 @@ public ChannelState emptyState() { sequenceNumberInAnnouncedChannels.clear(); return this; } + + /** + * Transfers spill-snapshot ownership to the writer after all inputs observe checkpoint start. 
+ */ + public void onCheckpointStartedForAllInputs(CheckpointBarrier barrier) + throws CheckpointException, IOException { + long cpId = barrier.getId(); + FetchedChannelStateReader snap = null; + try { + snap = recoveryCheckpointTrigger.snapshotAndInsertBarriers(cpId); + + for (CheckpointableInput input : inputs) { + input.checkpointStarted(barrier); + } + + channelStateWriter.addInputDataFromSpill(cpId, snap); + } catch (Throwable t) { + if (snap != null) { + try { + snap.close(); + } catch (Exception suppressed) { + t.addSuppressed(suppressed); + } + } + if (t instanceof CheckpointException) { + throw (CheckpointException) t; + } + ExceptionUtils.rethrowIOException(t); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java index ad0dcb5b0bb74b..26f335386a198d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java @@ -29,9 +29,11 @@ import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.EventAnnouncement; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.EndOfFetchedChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfOutputChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput; import org.slf4j.Logger; @@ -202,6 +204,15 @@ private Optional handleEvent(BufferOrEvent 
bufferOrEvent) throws bufferOrEvent.getChannelInfo()); } else if (bufferOrEvent.getEvent().getClass() == EndOfOutputChannelStateEvent.class) { upstreamRecoveryTracker.handleEndOfRecovery(bufferOrEvent.getChannelInfo()); + } else if (eventClass == EndOfFetchedChannelStateEvent.class) { + // Tail of the recovered buffers: only a RecoverableInputChannel can produce this + // sentinel, so anything else here is a bug rather than something to tolerate. + InputChannel channel = inputGate.getChannel(bufferOrEvent.getChannelInfo()); + checkState( + channel instanceof RecoverableInputChannel, + "EndOfFetchedChannelStateEvent received on a non-recoverable channel %s", + bufferOrEvent.getChannelInfo()); + ((RecoverableInputChannel) channel).onRecoveredStateConsumed(); } return Optional.of(bufferOrEvent); } @@ -296,6 +307,10 @@ public InputChannel getChannel(int channelIndex) { return inputGate.getChannel(channelIndex); } + public InputChannel getChannel(InputChannelInfo channelInfo) { + return inputGate.getChannel(channelInfo); + } + public List getChannelInfos() { return inputGate.getChannelInfos(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java index 6d8a0268dadecd..e299399e9936d6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java @@ -22,6 +22,8 @@ import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.execution.CheckpointingMode; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import 
org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; @@ -88,6 +90,30 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( List> sourceInputs, MailboxExecutor mailboxExecutor, TimerService timerService) { + return createCheckpointBarrierHandler( + toNotifyOnCheckpoint, + jobConf, + config, + checkpointCoordinator, + taskName, + inputGates, + sourceInputs, + mailboxExecutor, + timerService, + RecoveryCheckpointTrigger.NO_OP); + } + + public static CheckpointBarrierHandler createCheckpointBarrierHandler( + CheckpointableTask toNotifyOnCheckpoint, + Configuration jobConf, + StreamConfig config, + SubtaskCheckpointCoordinator checkpointCoordinator, + String taskName, + List[] inputGates, + List> sourceInputs, + MailboxExecutor mailboxExecutor, + TimerService timerService, + RecoveryCheckpointTrigger recoveryCheckpointTrigger) { CheckpointableInput[] inputs = Stream.concat( @@ -98,6 +124,7 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( Clock clock = SystemClock.getInstance(); CheckpointingMode checkpointingMode = CheckpointingOptions.getCheckpointingMode(jobConf); + ChannelStateWriter channelStateWriter = checkpointCoordinator.getChannelStateWriter(); switch (checkpointingMode) { case EXACTLY_ONCE: int numberOfChannels = @@ -115,7 +142,9 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( timerService, inputs, clock, - numberOfChannels); + numberOfChannels, + recoveryCheckpointTrigger, + channelStateWriter); case AT_LEAST_ONCE: if (CheckpointingOptions.isUnalignedCheckpointEnabled(jobConf)) { throw new IllegalStateException( @@ -148,7 +177,9 @@ private static SingleCheckpointBarrierHandler createBarrierHandler( TimerService timerService, CheckpointableInput[] inputs, Clock clock, - int numberOfChannels) { + int numberOfChannels, 
+ RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter) { boolean enableCheckpointAfterTasksFinished = config.getConfiguration() .get(CheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH); @@ -161,6 +192,8 @@ private static SingleCheckpointBarrierHandler createBarrierHandler( numberOfChannels, BarrierAlignmentUtil.createRegisterTimerCallback(mailboxExecutor, timerService), enableCheckpointAfterTasksFinished, + recoveryCheckpointTrigger, + channelStateWriter, inputs); } else { return SingleCheckpointBarrierHandler.aligned( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java index c1bd9ad6c85616..547941c24474d8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java @@ -22,7 +22,9 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; @@ -119,6 +121,8 @@ public static SingleCheckpointBarrierHandler createUnalignedCheckpointBarrierHan "Strictly unaligned checkpoints should never register any callbacks"); }, enableCheckpointsAfterTasksFinish, + RecoveryCheckpointTrigger.NO_OP, + 
ChannelStateWriter.NO_OP, inputs); } @@ -130,6 +134,8 @@ public static SingleCheckpointBarrierHandler unaligned( int numOpenChannels, DelayableTimer registerTimer, boolean enableCheckpointAfterTasksFinished, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter, CheckpointableInput... inputs) { return new SingleCheckpointBarrierHandler( taskName, @@ -137,7 +143,9 @@ public static SingleCheckpointBarrierHandler unaligned( checkpointCoordinator, clock, numOpenChannels, - new AlternatingWaitingForFirstBarrierUnaligned(false, new ChannelState(inputs)), + new AlternatingWaitingForFirstBarrierUnaligned( + false, + new ChannelState(inputs, recoveryCheckpointTrigger, channelStateWriter)), false, registerTimer, inputs, @@ -173,6 +181,8 @@ public static SingleCheckpointBarrierHandler alternating( int numOpenChannels, DelayableTimer registerTimer, boolean enableCheckpointAfterTasksFinished, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter, CheckpointableInput... 
inputs) { return new SingleCheckpointBarrierHandler( taskName, @@ -180,7 +190,8 @@ public static SingleCheckpointBarrierHandler alternating( checkpointCoordinator, clock, numOpenChannels, - new AlternatingWaitingForFirstBarrier(new ChannelState(inputs)), + new AlternatingWaitingForFirstBarrier( + new ChannelState(inputs, recoveryCheckpointTrigger, channelStateWriter)), true, registerTimer, inputs, diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java index e207eb6213edc5..ae0e364ab712ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java @@ -91,7 +91,7 @@ public int getNumberOfChannels() { /** Maximum parallelism for configuring partitioners. */ private final int maxParallelism; - /** Temporary directories for spilling spanning records. Can be empty but never null. */ + /** Temporary directories for spilling spanning records. Never null; may be empty if disabled. */ private final String[] tmpDirectories; /** Whether unaligned checkpointing during recovery is enabled. */ @@ -99,8 +99,8 @@ public int getNumberOfChannels() { /** * Network buffer memory segment size in bytes (from taskmanager.memory.segment-size). Used to - * size the reusable heap source buffer in {@code InputChannelRecoveredStateHandler} so it - * matches the network buffer size. + * size the reusable heap source buffer in {@code SpillingWithFilteringHandler} so it matches + * the network buffer size. */ private final int memorySegmentSize; @@ -112,8 +112,9 @@ public int getNumberOfChannels() { * @param rescalingDescriptor Descriptor containing rescaling information. Not null. * @param subtaskIndex Current subtask index. * @param maxParallelism Maximum parallelism.
- * @param tmpDirectories Temporary directories for spilling spanning records. Can be null - * (converted to empty array). + * @param tmpDirectories Temporary directories for spilling spanning records, normally from + * {@code IOManager.getSpillingDirectoriesPaths()}. Must be non-null and non-empty if + * checkpointing during recovery is enabled; otherwise null or empty is tolerated. * @param checkpointingDuringRecoveryEnabled Whether unaligned checkpointing during recovery is * enabled. * @param memorySegmentSize Network buffer memory segment size in bytes. Must be positive. @@ -130,8 +131,17 @@ public RecordFilterContext( this.rescalingDescriptor = checkNotNull(rescalingDescriptor); this.subtaskIndex = subtaskIndex; this.maxParallelism = maxParallelism; - this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new String[0]; this.checkpointingDuringRecoveryEnabled = checkpointingDuringRecoveryEnabled; + if (checkpointingDuringRecoveryEnabled) { + // tmpDirectories are only used by the spilling (checkpointing-during-recovery) path. + checkArgument( + checkNotNull(tmpDirectories).length > 0, "tmpDirectories must not be empty"); + this.tmpDirectories = tmpDirectories.clone(); + } else { + // A disabled context never spills, so it needs no spill directories; tolerate a + // null/empty value (e.g. an environment without IOManager spilling directories). + this.tmpDirectories = tmpDirectories == null ? new String[0] : tmpDirectories.clone(); + } checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; @@ -202,7 +212,7 @@ public int getMaxParallelism() { /** * Gets the temporary directories for spilling spanning records. * - * @return The temporary directories, never null (may be empty array). + * @return The temporary directories, never null; may be empty for a disabled context. */ public String[] getTmpDirectories() { return tmpDirectories; @@ -235,15 +245,19 @@ public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) { *

The returned context has empty inputConfigs and enabled=false, so {@link * #isCheckpointingDuringRecoveryEnabled()} will always return false. * + * @param tmpDirectories Temporary directories for spilling spanning records. May be null or + * empty: a disabled context never spills, so the constructor does not require spill + * directories (a null value is stored as an empty array). Callers that have the real I/O + * manager spilling directories may still pass them. + * @return A disabled RecordFilterContext. */ - public static RecordFilterContext disabled() { + public static RecordFilterContext disabled(String[] tmpDirectories) { return new RecordFilterContext( new InputFilterConfig[0], InflightDataRescalingDescriptor.NO_RESCALE, 0, 0, - new String[0], + tmpDirectories, false, MemoryManager.DEFAULT_PAGE_SIZE); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java index 6e45024e1ca558..e27876eccd749a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java @@ -149,7 +149,8 @@ protected void createInputProcessor( inputGates, operatorChain.getSourceTaskInputs(), mainMailboxExecutor, - timerService); + timerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java index 009f48b082f755..eb772462f1f922 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java @@
-174,7 +174,8 @@ private CheckpointedInputGate createCheckpointedInputGate() { new List[] {Arrays.asList(inputGates)}, Collections.emptyList(), mainMailboxExecutor, - systemTimerService); + systemTimerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index e06164125fd417..39c68db32b9b5a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -43,7 +43,10 @@ import org.apache.flink.runtime.checkpoint.SnapshotType; import org.apache.flink.runtime.checkpoint.SubTaskInitializationMetricsBuilder; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelState; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateDrainer; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReader; import org.apache.flink.runtime.checkpoint.filemerging.FileMergingSnapshotManager; import org.apache.flink.runtime.execution.CancelTaskException; @@ -62,7 +65,9 @@ import org.apache.flink.runtime.io.network.partition.ChannelStateHolder; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; 
import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointableTask; import org.apache.flink.runtime.jobgraph.tasks.CoordinatedTask; @@ -141,6 +146,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -159,6 +165,7 @@ import static org.apache.flink.util.ExceptionUtils.firstOrSuppressed; import static org.apache.flink.util.Preconditions.checkState; import static org.apache.flink.util.concurrent.FutureUtils.assertNoException; +import static org.apache.flink.util.concurrent.FutureUtils.completeAll; /** * Base class for all streaming tasks. A task is the unit of local processing that is deployed and @@ -304,6 +311,9 @@ public abstract class StreamTask> /** TODO it might be replaced by the global IO executor on TaskManager level future. 
*/ private final ExecutorService channelIOExecutor; + private RecoveryCheckpointTrigger recoveryCheckpointTrigger = + RecoveryCheckpointTrigger.NOT_READY; + // ======================================================== // Final checkpoint / savepoint // ======================================================== @@ -842,6 +852,12 @@ void restoreInternal() throws Exception { allGatesRecoveredFuture.isDone(), "Mailbox loop interrupted before recovery was finished."); + try { + allGatesRecoveredFuture.get(); + } catch (ExecutionException e) { + ExceptionUtils.rethrowException(e.getCause()); + } + // we recovered all the gates, we can close the channel IO executor as it is no longer // needed channelIOExecutor.shutdown(); @@ -881,58 +897,164 @@ private CompletableFuture restoreStateAndGates( INITIALIZE_STATE_DURATION, initializeStateEndTs - readOutputDataTs); IndexedInputGate[] inputGates = getEnvironment().getAllInputGates(); - boolean checkpointingDuringRecoveryEnabled = - CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + CompletableFuture recoveryCompletionFuture = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()) + ? recoverChannelsWithCheckpointing(reader, inputGates) + : recoverChannelsWithoutCheckpointing(reader, inputGates); - // Must set the flag on input gates BEFORE starting the async read task, because - // finishReadRecoveredState() checks this flag to complete bufferFilteringCompleteFuture. 
- for (IndexedInputGate inputGate : inputGates) { - inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); - } + recoveryCompletionFuture.whenComplete((ign, throwable) -> mailboxProcessor.suspend()); + return recoveryCompletionFuture; + } - channelIOExecutor.execute( - () -> { - try { - reader.readInputData(inputGates, createRecordFilterContext()); - } catch (Exception e) { - asyncExceptionHandler.handleAsyncException( - "Unable to read channel state", e); - } - }); + private CompletableFuture recoverChannelsWithCheckpointing( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + return setRecoveryCheckpointTrigger(RecoveryCheckpointTrigger.NOT_READY) + .thenApplyAsync(ign -> fetchChannelState(reader, inputGates), channelIOExecutor) + .thenCompose( + state -> + requestPartitions(inputGates, state.isPresent()) + .thenApply(channels -> buildDrainer(state, channels))) + .thenCompose(this::drainThroughCheckpointTrigger) + .thenRun(() -> setRecoveryCheckpointTrigger(RecoveryCheckpointTrigger.NO_OP)); + } - // We wait for all input channel state to recover before we go into RUNNING state, and thus - // start checkpointing. If we implement incremental checkpointing of input channel state - // we must make sure it supports CheckpointType#FULL_CHECKPOINT. - List> recoveredFutures = new ArrayList<>(inputGates.length); - for (InputGate inputGate : inputGates) { - CompletableFuture requestPartitionsTrigger = - checkpointingDuringRecoveryEnabled - ? 
inputGate.getBufferFilteringCompleteFuture() - : inputGate.getStateConsumedFuture(); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") // intentional: simplify call-site + private Optional buildDrainer( + Optional state, List channels) { + return state.map(s -> new FetchedChannelStateDrainer(s, channels)); + } - recoveredFutures.add(requestPartitionsTrigger); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") // intentional: simplify call-site + private CompletableFuture drainThroughCheckpointTrigger( + Optional drainer) { + if (drainer.isEmpty()) { + return FutureUtils.completedVoidFuture(); + } + FetchedChannelStateDrainer d = drainer.get(); + return setRecoveryCheckpointTrigger(d).thenRunAsync(() -> drain(d), channelIOExecutor); + } - requestPartitionsTrigger.thenRun( + private CompletableFuture setRecoveryCheckpointTrigger( + RecoveryCheckpointTrigger trigger) { + CompletableFuture future = new CompletableFuture<>(); + mainMailboxExecutor.execute( + () -> { + recoveryCheckpointTrigger = trigger; + future.complete(null); + }, + "update recoveryCheckpointTrigger to " + trigger); + return future; + } + + private CompletableFuture recoverChannelsWithoutCheckpointing( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + recoveryCheckpointTrigger = RecoveryCheckpointTrigger.NO_OP; + // Feed recovered channel state on the IO thread. This is intentionally NOT part of the + // completion gate below: a gate's stateConsumedFuture only completes once the consumer (the + // default action, running in the restore mailbox loop) has drained the end-of-state + // sentinel that readInputChannelState pushes, so gating on stateConsumedFuture already + // implies the read has finished. Including the read future in the gate would defer + // suspend() past the restore loop's first default action, letting input be processed before + // recovery completes -- causing record loss and "Mailbox loop interrupted before recovery + // was finished". 
Read failures are surfaced via asyncExceptionHandler inside + // readInputChannelState. + channelIOExecutor.execute(() -> readInputChannelState(reader, inputGates)); + List> futures = new ArrayList<>(); + for (InputGate inputGate : inputGates) { + CompletableFuture stateConsumed = inputGate.getStateConsumedFuture(); + futures.add(stateConsumed); + // Convert and request partitions for each gate as soon as ITS OWN recovered state is + // drained, independent of the other gates. Deferring conversion until every gate has + // drained deadlocks selective-reading multi-input operators: such an operator only + // drains the selected input's end-of-state sentinel, so an unselected gate would never + // drain (it is read only after conversion) while conversion would wait for it to drain + // first. suspend() is still gated on completeAll(futures) below. + stateConsumed.thenRun( () -> mainMailboxExecutor.execute( - inputGate::requestPartitions, "Input gate request partitions")); + () -> inputGate.requestPartitions(false), + "Input gate request partitions")); + } + return completeAll(futures); + } + + private void readInputChannelState( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + try { + checkState(reader.readInputData(inputGates, createRecordFilterContext()).isEmpty()); + + for (IndexedInputGate gate : inputGates) { + gate.finishReadRecoveredState(); // this is called from IO thread - is that fine? 
+ } + } catch (Exception e) { + asyncExceptionHandler.handleAsyncException( + "Unable to set up recovered channel state", e); + } + } + + private Optional fetchChannelState( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + try { + return reader.readInputData(inputGates, createRecordFilterContext()); + } catch (Throwable t) { // review: don't catch errors + asyncExceptionHandler.handleAsyncException( + "Unable to set up recovered channel state", t); + return Optional.empty(); + } + } + + private void drain(FetchedChannelStateDrainer drainer) { + try (FetchedChannelStateDrainer ignored = drainer) { + try { + drainer.drain(); + } catch (Throwable t) { + asyncExceptionHandler.handleAsyncException( + "Unable to drain recovered channel state", t); + } + } catch (Throwable closeError) { + asyncExceptionHandler.handleAsyncException( + "Unable to close FetchedChannelStateDrainer after drain", closeError); } + } - // Return allOf future instead of thenRun future. thenRun() returns a NEW future that - // completes only after the callback finishes. CompletableFuture executes thenRun callbacks - // synchronously on the thread that calls complete(). When recoveredFutures contains - // bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called - // on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on - // channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up - // and exit runMailboxLoop() before the thenRun future completes — causing - // checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on - // the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison - // mail until the future completes — no race. Returning allOf future avoids the issue - // entirely. 
- CompletableFuture allRecoveredFuture = - CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])); - allRecoveredFuture.thenRun(mailboxProcessor::suspend); - return allRecoveredFuture; + private CompletableFuture> requestPartitions( + IndexedInputGate[] inputGates, boolean needsRecovery) { + CompletableFuture> future = new CompletableFuture<>(); + mainMailboxExecutor.execute( + () -> { + try { + for (InputGate inputGate : inputGates) { + inputGate.requestPartitions(needsRecovery); + } + future.complete(collectPhysicalChannels(inputGates)); + } catch (Throwable t) { + future.completeExceptionally(t); + throw t; + } + }, + "Input gate request partitions"); + return future; + } + + private static List collectPhysicalChannels(InputGate[] inputGates) { + List channels = new ArrayList<>(); + for (InputGate gate : inputGates) { + int numberOfInputChannels = gate.getNumberOfInputChannels(); + for (int i = 0; i < numberOfInputChannels; i++) { + InputChannel channel = gate.getChannel(i); + if (channel instanceof RecoverableInputChannel) { + channels.add((RecoverableInputChannel) channel); + } + } + } + return channels; + } + + public RecoveryCheckpointTrigger getRecoveryCheckpointTrigger() { + return cpId -> { + checkState(mailboxProcessor.isMailboxThread()); + return recoveryCheckpointTrigger.snapshotAndInsertBarriers(cpId); + }; } private void ensureNotCanceled() { @@ -1996,7 +2118,11 @@ protected RecordFilterContext createRecordFilterContext() { boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); if (!checkpointingDuringRecoveryEnabled) { - return RecordFilterContext.disabled(); + // A disabled context never spills, so it needs no spill directories. Don't touch the + // IOManager here -- recovery then works even in minimal environments that don't provide + // one (e.g. DummyEnvironment-based tests), instead of NPEing and routing the failure to + // failExternally. 
+ return RecordFilterContext.disabled(new String[0]); } ClassLoader cl = getUserCodeClassLoader(); diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java index f933d5069d1e03..5ea0f544285074 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java @@ -72,7 +72,8 @@ protected void createInputProcessor( new List[] {inputGates1, inputGates2}, Collections.emptyList(), mainMailboxExecutor, - systemTimerService); + systemTimerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java new file mode 100644 index 00000000000000..812bcfebabf91f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.nio.file.Path; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests the on-disk spill format produced by {@link AbstractSpillingHandler}: the 12-byte segment + * header with backfilled buffer length, channel switching, file rotation, and empty-segment + * dropping. Records are appended through {@link TestSpillWriter}, mirroring how the filtering and + * pass-through subclasses feed the segment buffer. + */ +class AbstractSpillingHandlerTest { + + @TempDir Path tempDir; + + @Test + void testChannelSwitchProducesTwoSegmentsInOneFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(new InputChannelInfo(0, 0), bytes(0xAA, 0xBB), 2); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(0xCC, 0xDD, 0xEE), 3); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + int seg0Body = Integer.BYTES + 2; + int seg1Body = Integer.BYTES + 3; + assertThat(state.files().get(0).toFile().length()) + .isEqualTo( + (long) (AbstractSpillingHandler.SEGMENT_HEADER_BYTES + seg0Body) + + (AbstractSpillingHandler.SEGMENT_HEADER_BYTES + seg1Body)); + } + } + + @Test + void testSameChannelContinuousRecordsMergeIntoOneSegment() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1), 1); + writer.writeRecord(ch, bytes(2, 3), 2); + writer.writeRecord(ch, bytes(4, 5, 6), 3); + FetchedChannelState state = writer.getChannelState(); + + long expectedBody = 3L * Integer.BYTES + (1 + 2 + 3); + assertThat(state.files()).hasSize(1); + 
assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + expectedBody); + } + } + + @Test + void testPassThroughBytesAreWrittenVerbatim() throws Exception { + InputChannelInfo ch = new InputChannelInfo(1, 2); + byte[] data = bytes(0x01, 0x02, 0x03, 0x04, 0x05); + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writePassThrough(ch, data, 0, data.length); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + data.length); + } + } + + // ------------------------------------------------------------------------------------------- + // Empty segments: opening a segment without writing a body must not create a file + // ------------------------------------------------------------------------------------------- + + @Test + void testOpenSegmentWithoutBodyProducesNoFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.openSegment(new InputChannelInfo(0, 0)); + // No body was ever spilled, so no state (and therefore no file) is produced. 
+ assertThat(writer.getChannelState()).isNull(); + } + } + + @Test + void testEmptySegmentsAroundANonEmptyOneProduceExactlyTheNonEmptyFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.openSegment(new InputChannelInfo(0, 0)); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(7, 8, 9), 3); + writer.openSegment(new InputChannelInfo(0, 2)); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + Integer.BYTES + 3); + } + } + + // ------------------------------------------------------------------------------------------- + // File rotation + // ------------------------------------------------------------------------------------------- + + @Test + void testRotationProducesOneFilePerSegmentWhenBoundIsTiny() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L)) { + writer.writeRecord(new InputChannelInfo(0, 0), bytes(1), 1); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(2, 3), 2); + writer.writeRecord(new InputChannelInfo(0, 2), bytes(4), 1); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(3); + for (Path file : state.files()) { + assertThat(file.toFile().length()).isGreaterThan(0); + } + } + } + + // ------------------------------------------------------------------------------------------- + // Disk-format verification + // ------------------------------------------------------------------------------------------- + + @Test + void testDiskFormatMatchesSpec() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(new InputChannelInfo(7, 3), bytes(0xAB, 0xCD, 0xEF), 3); + Path file = writer.getChannelState().files().get(0); + try (DataInputStream in = new DataInputStream(new FileInputStream(file.toFile()))) { + 
assertThat(in.readInt()).isEqualTo(7); // gateIdx + assertThat(in.readInt()).isEqualTo(3); // channelIdx + assertThat(in.readInt()).isEqualTo(Integer.BYTES + 3); // bufferLength + assertThat(in.readInt()).isEqualTo(3); // record length prefix + assertThat(in.read()).isEqualTo(0xAB); + assertThat(in.read()).isEqualTo(0xCD); + assertThat(in.read()).isEqualTo(0xEF); + assertThat(in.read()).isEqualTo(-1); // EOF + } + } + } + + private static byte[] bytes(int... values) { + byte[] out = new byte[values.length]; + for (int i = 0; i < values.length; i++) { + out[i] = (byte) values[i]; + } + return out; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java new file mode 100644 index 00000000000000..7bbc7e6e8f5d4e --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Runs {@link FetchedChannelStateDrainer#drain()} on a dedicated single-threaded executor and
+ * checks what the submitting thread can observe: delivered buffers and the finish call on success,
+ * and the thrown exception when a channel fails.
+ */
+class ChannelIOExecutorDrainSubmissionTest {
+
+    @TempDir Path tempDir;
+
+    /** A successful drain delivers at least one data buffer and calls finish on the channel. */
+    @Test
+    void testFilterOnSubmitsDrainAfterConversion() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, new byte[] {1, 2, 3});
+
+        CapturingChannel chan = new CapturingChannel(cInfo);
+        List<RecoverableInputChannel> all = new ArrayList<>();
+        all.add(chan);
+        FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all);
+
+        ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor();
+        try {
+            // Completing the future after drain() publishes the channel's counters to this thread.
+            CompletableFuture<Void> done = new CompletableFuture<>();
+            channelIOExecutor.execute(
+                    () -> {
+                        try {
+                            drainer.drain();
+                            done.complete(null);
+                        } catch (Throwable t) {
+                            done.completeExceptionally(t);
+                        } finally {
+                            try {
+                                drainer.close();
+                            } catch (IOException ignore) {
+                                // Best-effort cleanup; the outcome is reported through `done`.
+                            }
+                        }
+                    });
+
+            done.get(5, TimeUnit.SECONDS);
+            assertThat(chan.dataDeliveries).isGreaterThan(0);
+            assertThat(chan.finishCalled).isTrue();
+        } finally {
+            channelIOExecutor.shutdownNow();
+            assertThat(channelIOExecutor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
+        }
+    }
+
+    /** An exception thrown by the channel during drain reaches the catch block of the task. */
+    @Test
+    void testDrainExceptionBubblesViaAsyncExceptionHandler() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, new byte[] {1, 2, 3});
+
+        // Channel that fails on every delivery attempt.
+        RecoverableInputChannel chan =
+                new RecoverableInputChannel() {
+                    @Override
+                    public InputChannelInfo getChannelInfo() {
+                        return cInfo;
+                    }
+
+                    @Override
+                    public void onRecoveredStateBuffer(Buffer buffer) {
+                        throw new RuntimeException("boom");
+                    }
+
+                    @Override
+                    public void finishRecoveredBufferDelivery() {}
+
+                    @Override
+                    public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) {
+                        throw new RuntimeException("boom");
+                    }
+
+                    @Override
+                    public Buffer requestRecoveryBufferBlocking() {
+                        MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(64);
+                        return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE);
+                    }
+
+                    @Override
+                    public void onRecoveredStateConsumed() {}
+                };
+
+        List<RecoverableInputChannel> all = new ArrayList<>();
+        all.add(chan);
+        FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all);
+
+        CountDownLatch handlerCalled = new CountDownLatch(1);
+        // Typed as Throwable so the message can be asserted without a cast.
+        AtomicReference<Throwable> captured = new AtomicReference<>();
+        ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor();
+        try {
+            channelIOExecutor.execute(
+                    () -> {
+                        try {
+                            drainer.drain();
+                        } catch (Throwable t) {
+                            captured.set(t);
+                            handlerCalled.countDown();
+                        } finally {
+                            try {
+                                drainer.close();
+                            } catch (IOException ignore) {
+                                // Best-effort cleanup; the failure under test is already captured.
+                            }
+                        }
+                    });
+
+            assertThat(handlerCalled.await(5, TimeUnit.SECONDS)).isTrue();
+            assertThat(captured.get()).isInstanceOf(RuntimeException.class);
+            assertThat(captured.get().getMessage()).isEqualTo("boom");
+        } finally {
+            channelIOExecutor.shutdownNow();
+            assertThat(channelIOExecutor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    /** Spills a single record for {@code ch} into {@link #tempDir} and returns the state. */
+    private FetchedChannelState writeRecords(InputChannelInfo ch, byte[] payload)
+            throws IOException {
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, payload, payload.length);
+            return writer.getChannelState();
+        }
+    }
+
+    /** Channel stub that counts data-buffer deliveries and records whether finish was called. */
+    private static final class CapturingChannel implements RecoverableInputChannel {
+        private final InputChannelInfo channelInfo;
+        int dataDeliveries = 0;
+        boolean finishCalled = false;
+
+        CapturingChannel(InputChannelInfo channelInfo) {
+            this.channelInfo = channelInfo;
+        }
+
+        @Override
+        public InputChannelInfo getChannelInfo() {
+            return channelInfo;
+        }
+
+        @Override
+        public void onRecoveredStateBuffer(Buffer buffer) {
+            // Only data buffers are counted; event buffers are ignored.
+            if (buffer.isBuffer()) {
+                dataDeliveries++;
+            }
+        }
+
+        @Override
+        public void finishRecoveredBufferDelivery() {
+            finishCalled = true;
+        }
+
+        @Override
+        public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) {}
+
+        @Override
+        public Buffer requestRecoveryBufferBlocking() {
+            MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(64);
+            return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE);
+        }
+
+        @Override
+        public void onRecoveredStateConsumed() {}
+    }
+}
+diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java
new file mode 100644
index 00000000000000..e9581fa34514dd
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link ChannelStateFilteringHandler}. 
*/
+class ChannelStateFilteringHandlerTest {
+
+    @TempDir Path tempDir;
+
+    /** Creating a handler from a context that names the temp directory yields a handler. */
+    @Test
+    void testCreateFromContextUsesProvidedSpillDirectories() {
+        InputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(1).build();
+        String[] spillDirectories = {tempDir.toString()};
+        RecordFilterContext context = createRecordFilterContext(spillDirectories);
+        InputGate[] gates = {inputGate};
+
+        ChannelStateFilteringHandler handler =
+                ChannelStateFilteringHandler.createFromContext(context, gates);
+
+        assertThat(handler).isNotNull();
+        handler.close();
+    }
+
+    /**
+     * Builds a context with one input filter config and one identity-mapped rescaling descriptor,
+     * using the given directories as the temporary directories argument.
+     */
+    private static RecordFilterContext createRecordFilterContext(String[] tmpDirectories) {
+        RecordFilterContext.InputFilterConfig[] filterConfigs = {
+            new RecordFilterContext.InputFilterConfig(
+                    LongSerializer.INSTANCE, new ForwardPartitioner<>(), 1)
+        };
+
+        InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor
+                gateDescriptor =
+                        new InflightDataRescalingDescriptor
+                                .InflightDataGateOrPartitionRescalingDescriptor(
+                                new int[] {0},
+                                RescaleMappings.identity(1, 1),
+                                new HashSet<>(),
+                                InflightDataRescalingDescriptor
+                                        .InflightDataGateOrPartitionRescalingDescriptor
+                                        .MappingType.IDENTITY);
+
+        InflightDataRescalingDescriptor rescalingDescriptor =
+                new InflightDataRescalingDescriptor(
+                        new InflightDataRescalingDescriptor
+                                        .InflightDataGateOrPartitionRescalingDescriptor[] {
+                            gateDescriptor
+                        });
+
+        return new RecordFilterContext(
+                filterConfigs,
+                rescalingDescriptor,
+                0,
+                128,
+                tmpDirectories,
+                true,
+                MemoryManager.DEFAULT_PAGE_SIZE);
+    }
+}
+diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java
new file mode 100644
index 00000000000000..22fbe3d7d5efbb
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.nio.file.Path;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@code ChannelStateWriterImpl#addInputDataFromSpill}: the reader is handed to the
+ * request executor, and the backing {@link FetchedChannelState} is closed only once the request
+ * has been processed.
+ */
+class ChannelStateWriterImplAddInputDataFromSpillTest {
+
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
+    private static final long CHECKPOINT_ID = 7L;
+    private static final String TASK_NAME = "test";
+
+    @TempDir Path tempDir;
+
+    /** State spanning two channels stays open while queued and is closed after processing. */
+    @Test
+    void testNonEmptySnapshotAsyncDemux() throws Exception {
+        SyncChannelStateWriteRequestExecutor worker =
+                new SyncChannelStateWriteRequestExecutor(JOB_ID);
+        try (ChannelStateWriterImpl writer = newWriter(worker)) {
+            worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+            writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+            InputChannelInfo c0 = new InputChannelInfo(0, 0);
+            InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+            FetchedChannelState state;
+            try (TestSpillWriter spillWriter = new TestSpillWriter(tempDir)) {
+                spillWriter.writeRecord(c0, new byte[] {1, 2, 3}, 3);
+                spillWriter.writeRecord(c1, new byte[] {4, 5}, 2);
+                spillWriter.writeRecord(c0, new byte[] {6}, 1);
+                state = spillWriter.getChannelState();
+            }
+            FetchedChannelStateReader reader = state.reader();
+            // Drop the handoff grant; the reader now holds the only outstanding grant.
+            state.release();
+
+            writer.addInputDataFromSpill(CHECKPOINT_ID, reader);
+            // Request is queued but not yet processed — state must still be alive.
+            assertThat(state.isClosed()).isFalse();
+
+            worker.processAllRequests();
+            // After processing, the reader is closed by the request, releasing the last grant.
+            assertThat(state.isClosed()).isTrue();
+        }
+    }
+
+    /** A reader over empty state still results in at least one submission to the executor. */
+    @Test
+    void testEmptySnapshotStillSubmitted() throws Exception {
+        // Empty readers are no longer short-circuited; they are submitted to the writer thread.
+        QueueCountingExecutor worker = new QueueCountingExecutor();
+        try (ChannelStateWriterImpl writer = newWriter(worker)) {
+            worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+            writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+            int submittedBefore = worker.submitCount.get();
+            FetchedChannelState emptyState =
+                    new FetchedChannelState(java.util.Collections.emptyList());
+            FetchedChannelStateReader emptyReader = emptyState.reader();
+            emptyState.release();
+
+            writer.addInputDataFromSpill(CHECKPOINT_ID, emptyReader);
+
+            assertThat(worker.submitCount.get())
+                    .as("empty reader must still be submitted to the writer thread")
+                    .isGreaterThan(submittedBefore);
+        }
+    }
+
+    /** Single-record state is closed once the queued request has been processed. */
+    @Test
+    void testSegmentsClosedOnSuccess() throws Exception {
+        SyncChannelStateWriteRequestExecutor worker =
+                new SyncChannelStateWriteRequestExecutor(JOB_ID);
+        try (ChannelStateWriterImpl writer = newWriter(worker)) {
+            worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+            writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+            FetchedChannelState state;
+            try (TestSpillWriter spillWriter = new TestSpillWriter(tempDir)) {
+                spillWriter.writeRecord(new InputChannelInfo(0, 0), new byte[] {1}, 1);
+                state = spillWriter.getChannelState();
+            }
+            FetchedChannelStateReader reader = state.reader();
+            state.release();
+
+            writer.addInputDataFromSpill(CHECKPOINT_ID, reader);
+            worker.processAllRequests();
+
+            // After processing, the last grant is released and the state is cleaned up.
+            assertThat(state.isClosed()).isTrue();
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    /**
+     * Creates a writer bound to the shared test identifiers. Accepts any executor implementation
+     * so both the synchronous executor and the counting stub go through the same construction.
+     */
+    private ChannelStateWriterImpl newWriter(ChannelStateWriteRequestExecutor worker) {
+        return new ChannelStateWriterImpl(
+                JOB_VERTEX_ID, TASK_NAME, SUBTASK_INDEX, new ConcurrentHashMap<>(), worker, 5);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Executor stubs
+    // -------------------------------------------------------------------------------------------
+
+    /** Executor stub that discards every request and only counts how many were submitted. */
+    private static final class QueueCountingExecutor implements ChannelStateWriteRequestExecutor {
+
+        final AtomicInteger submitCount = new AtomicInteger(0);
+
+        @Override
+        public void submit(ChannelStateWriteRequest e) {
+            submitCount.incrementAndGet();
+        }
+
+        @Override
+        public void submitPriority(ChannelStateWriteRequest e) {
+            submitCount.incrementAndGet();
+        }
+
+        @Override
+        public void start() throws IllegalStateException {}
+
+        @Override
+        public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
+
+        @Override
+        public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
+    }
+}
+diff --git
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java new file mode 100644 index 00000000000000..63824a02d8c81a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Stress test for {@link FetchedChannelStateDrainer} drain / snapshot critical-section atomicity. + * Spawns one drain thread alongside multiple concurrent {@code snapshotAndInsertBarriers} calls and + * asserts that no buffer is delivered both via drain and snapshot, and that every segment appears + * in exactly one place. 
+ */
+class FetchedChannelStateDrainerConcurrencyTest {
+
+    private static final int RECORD_COUNT = 500;
+    private static final int SNAPSHOTS = 50;
+    private static final int CHANNEL_COUNT = 2;
+
+    @TempDir Path tempDir;
+
+    /**
+     * Drains on a background thread while the test thread takes {@link #SNAPSHOTS} snapshots, then
+     * checks that the drain finished without error and both channels recorded the same number of
+     * barriers.
+     */
+    @RepeatedTest(3)
+    void testDrainAndSnapshotConcurrentAtomicity() throws Exception {
+        // Each repetition gets its own directory under the shared temp dir.
+        Path runDir = Files.createTempDirectory(tempDir, "drain-stress-");
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(runDir)) {
+            for (int i = 0; i < RECORD_COUNT; i++) {
+                // Alternate records between the two channels.
+                InputChannelInfo ch = (i % CHANNEL_COUNT == 0) ? c0 : c1;
+                writer.writeRecord(ch, payloadFor(i), 8);
+            }
+            state = writer.getChannelState();
+        }
+
+        ThreadSafeRecordingChannel chan0 = new ThreadSafeRecordingChannel(c0);
+        ThreadSafeRecordingChannel chan1 = new ThreadSafeRecordingChannel(c1);
+        List<RecoverableInputChannel> all = new ArrayList<>();
+        all.add(chan0);
+        all.add(chan1);
+
+        FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all);
+
+        ExecutorService io = Executors.newSingleThreadExecutor();
+        AtomicReference<Throwable> drainError = new AtomicReference<>();
+        List<FetchedChannelStateReader> snapshots = new ArrayList<>();
+        try {
+            Future<?> drainFuture =
+                    io.submit(
+                            () -> {
+                                try {
+                                    drainer.drain();
+                                } catch (Throwable t) {
+                                    drainError.set(t);
+                                }
+                            });
+
+            // Take snapshots concurrently while drain runs.
+            for (int i = 0; i < SNAPSHOTS; i++) {
+                snapshots.add(drainer.snapshotAndInsertBarriers(i + 1));
+                Thread.yield();
+            }
+
+            drainFuture.get(60, TimeUnit.SECONDS);
+            io.shutdown();
+            assertThat(io.awaitTermination(10, TimeUnit.SECONDS)).isTrue();
+            if (drainError.get() != null) {
+                throw new AssertionError("drain failed", drainError.get());
+            }
+        } finally {
+            // Stops the drain thread if the wait above timed out or failed; no-op after shutdown.
+            io.shutdownNow();
+        }
+
+        // Close all snapshots
+        for (FetchedChannelStateReader snap : snapshots) {
+            snap.close();
+        }
+
+        // Both channels must have recorded the same number of barriers.
+        assertThat(chan0.barrierCount()).isEqualTo(chan1.barrierCount());
+
+        drainer.close();
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    /** Builds an 8-byte payload: {@code id} in little-endian order followed by four 0xCC bytes. */
+    private static byte[] payloadFor(int id) {
+        byte[] out = new byte[8];
+        out[0] = (byte) (id & 0xff);
+        out[1] = (byte) ((id >> 8) & 0xff);
+        out[2] = (byte) ((id >> 16) & 0xff);
+        out[3] = (byte) ((id >> 24) & 0xff);
+        Arrays.fill(out, 4, 8, (byte) 0xCC);
+        return out;
+    }
+
+    /** Channel stub whose recording methods are synchronized on the channel instance. */
+    private static final class ThreadSafeRecordingChannel implements RecoverableInputChannel {
+        private final InputChannelInfo channelInfo;
+        private final List<Buffer> data = new ArrayList<>();
+        private final List<Buffer> barriers = new ArrayList<>();
+
+        ThreadSafeRecordingChannel(InputChannelInfo channelInfo) {
+            this.channelInfo = channelInfo;
+        }
+
+        @Override
+        public InputChannelInfo getChannelInfo() {
+            return channelInfo;
+        }
+
+        @Override
+        public synchronized void onRecoveredStateBuffer(Buffer buffer) {
+            // Data buffers and event buffers are kept in separate lists.
+            if (buffer.isBuffer()) {
+                data.add(buffer);
+            } else {
+                barriers.add(buffer);
+            }
+        }
+
+        @Override
+        public synchronized void finishRecoveredBufferDelivery() {}
+
+        @Override
+        public synchronized void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId)
+                throws IOException {
+            barriers.add(
+                    EventSerializer.toBuffer(new RecoveryCheckpointBarrier(checkpointId), false));
+        }
+
+        @Override
+        public Buffer requestRecoveryBufferBlocking() {
+            // Use a buffer large enough for a full segment
+            MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(4096);
+            return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE);
+        }
+
+        @Override
+        public synchronized void onRecoveredStateConsumed() {}
+
+        synchronized int barrierCount() {
+            return barriers.size();
+        }
+    }
+}
+diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java
new file mode 100644
index 00000000000000..d791926e6a0abf
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link FetchedChannelStateDrainer}: drain demux, finish ordering, snapshot start
+ * position, barrier insertion, and edge cases (drain-finished, channel-not-in-recovery).
+ */
+class FetchedChannelStateDrainerTest {
+
+    @TempDir Path tempDir;
+
+    @Test
+    void testDrainEndToEnd() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2), payload(3));
+
+        RecordingChannel rec = new RecordingChannel(cInfo);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec);
+
+        drainer.drain();
+        drainer.close();
+
+        // The three records may be batched into fewer buffers than there are records, so only
+        // "something was delivered" is asserted here, not an exact buffer count.
+        assertThat(rec.recovered).isNotEmpty();
+        assertThat(rec.finishCalls).isEqualTo(1);
+    }
+
+    @Test
+    void testDrainSegmentLargerThanBufferSplitsIntoFullChunksThenPartialTail() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+
+        // Buffer capacity deliberately smaller than the segment body so the drainer must fill
+        // multiple buffers and a final partial tail. 50 bytes over a 16-byte buffer => 16+16+16+2.
+        int bufferCapacity = 16;
+        byte[] body = sequentialBytes(50);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            // Pass-through so the segment body equals the verbatim bytes (no length framing).
+            writer.writePassThrough(cInfo, body, 0, body.length);
+            state = writer.getChannelState();
+        }
+
+        RecordingChannel rec = new RecordingChannel(cInfo, bufferCapacity);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec);
+
+        drainer.drain();
+        drainer.close();
+
+        // ceil(50 / 16) = 4 buffers delivered.
+        assertThat(rec.recovered).hasSize(4);
+        // Every buffer except the last is filled to capacity; the last carries the remainder.
+        for (int i = 0; i < rec.recovered.size() - 1; i++) {
+            assertThat(rec.recovered.get(i).getSize()).isEqualTo(bufferCapacity);
+        }
+        assertThat(rec.recovered.get(rec.recovered.size() - 1).getSize())
+                .isEqualTo(body.length % bufferCapacity);
+
+        // Buffers concatenated in delivery order must reproduce the segment body byte-for-byte.
+        assertThat(concat(rec.recovered)).isEqualTo(body);
+        assertThat(rec.finishCalls).isEqualTo(1);
+    }
+
+    @Test
+    void testDrainSegmentExactMultipleOfBufferHasNoPartialTail() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+
+        // Body length is an exact multiple of the buffer capacity: the final read hits EOF on a
+        // freshly requested buffer, which must be recycled rather than delivered empty.
+        int bufferCapacity = 16;
+        byte[] body = sequentialBytes(bufferCapacity * 3);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writePassThrough(cInfo, body, 0, body.length);
+            state = writer.getChannelState();
+        }
+
+        RecordingChannel rec = new RecordingChannel(cInfo, bufferCapacity);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec);
+
+        drainer.drain();
+        drainer.close();
+
+        // Exactly 3 full buffers, no trailing empty buffer.
+        assertThat(rec.recovered).hasSize(3);
+        for (Buffer b : rec.recovered) {
+            assertThat(b.getSize()).isEqualTo(bufferCapacity);
+        }
+        assertThat(concat(rec.recovered)).isEqualTo(body);
+        assertThat(rec.finishCalls).isEqualTo(1);
+    }
+
+    @Test
+    void testDrainDemuxByChannelInfo() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, payload(11), payload(11).length);
+            writer.writeRecord(c1, payload(22), payload(22).length);
+            writer.writeRecord(c0, payload(33), payload(33).length);
+            writer.writeRecord(c1, payload(44), payload(44).length);
+            state = writer.getChannelState();
+        }
+
+        RecordingChannel chan0 = new RecordingChannel(c0);
+        RecordingChannel chan1 = new RecordingChannel(c1);
+        FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1);
+
+        drainer.drain();
+        drainer.close();
+
+        // Each channel must receive some data buffers
+        assertThat(chan0.recovered).isNotEmpty();
+        assertThat(chan1.recovered).isNotEmpty();
+        // Both channels must have finish called
+        assertThat(chan0.finishCalls).isEqualTo(1);
+        assertThat(chan1.finishCalls).isEqualTo(1);
+    }
+
+    @Test
+    void testDrainCallsFinishAfterAllBufferDeliveries() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, payload(1), payload(1).length);
+            writer.writeRecord(c1, payload(2), payload(2).length);
+            state = writer.getChannelState();
+        }
+
+        // Shared counter so deliveries and finish calls across both channels are totally ordered.
+        int[] seq = {0};
+        RecordingChannel chan0 = new RecordingChannel(c0, seq);
+        RecordingChannel chan1 = new RecordingChannel(c1, seq);
+        FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1);
+
+        drainer.drain();
+        drainer.close();
+
+        int maxDataSeq = Math.max(chan0.maxDataSeq, chan1.maxDataSeq);
+        int minFinishSeq = Math.min(chan0.finishSeq, chan1.finishSeq);
+        assertThat(maxDataSeq).isLessThan(minFinishSeq);
+    }
+
+    @Test
+    void testSnapshotCoversAllSegmentsBeforeDrain() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, payload(5), payload(6));
+
+        RecordingChannel chan = new RecordingChannel(cInfo);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan);
+
+        long cpId = 42L;
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId);
+
+        // Snapshot must cover all segments (at least 1 segment for cInfo).
+        // The sequential reader requires each segment body to be fully consumed before advancing,
+        // mirroring the real consumer (ChannelStateCheckpointWriter#writeInputFromSpill).
+        int count = 0;
+        Optional<SpillSegment> next;
+        while ((next = snap.nextSegment()).isPresent()) {
+            drainBody(next.get().bodyStream());
+            count++;
+        }
+        snap.close();
+        assertThat(count).isGreaterThan(0);
+        drainer.close();
+    }
+
+    @Test
+    void testSnapshotInsertsBarrierPerChannel() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, payload(1), payload(1).length);
+            writer.writeRecord(c1, payload(2), payload(2).length);
+            state = writer.getChannelState();
+        }
+
+        RecordingChannel chan0 = new RecordingChannel(c0);
+        RecordingChannel chan1 = new RecordingChannel(c1);
+        FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1);
+
+        long cpId = 7L;
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId);
+        snap.close();
+
+        assertThat(chan0.recovered).hasSize(1);
+        assertThat(chan1.recovered).hasSize(1);
+        assertThat(extractRecoveryBarrierCheckpointId(chan0.recovered.get(0))).isEqualTo(cpId);
+        assertThat(extractRecoveryBarrierCheckpointId(chan1.recovered.get(0))).isEqualTo(cpId);
+        drainer.close();
+    }
+
+    @Test
+    void testSnapshotInsertsBarrierWhenChannelInRecoveryEvenIfDiskSliceEmpty() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, payload(1));
+
+        RecordingChannel chan = new RecordingChannel(cInfo);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan);
+
+        drainer.drain();
+        // Drain finished; simulate the channel still in recovery
+        chan.inRecovery = true;
+        int recoveredBefore = chan.recovered.size();
+
+        long cpId = 6L;
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId);
+        assertThat(snap.nextSegment()).isEmpty();
+        snap.close();
+
+        // Barrier must be inserted even though disk slice is empty
+        assertThat(chan.recovered).hasSize(recoveredBefore + 1);
+        assertThat(extractRecoveryBarrierCheckpointId(chan.recovered.get(recoveredBefore)))
+                .isEqualTo(cpId);
+        drainer.close();
+    }
+
+    @Test
+    void testSnapshotInsertsBarrierOnlyForChannelsStillInRecovery() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, payload(1), payload(1).length);
+            state = writer.getChannelState();
+        }
+
+        RecordingChannel chan0 = new RecordingChannel(c0);
+        RecordingChannel chan1 = new RecordingChannel(c1);
+        chan1.inRecovery = false;
+
+        FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1);
+
+        long cpId = 11L;
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId);
+        snap.close();
+
+        assertThat(chan0.recovered).hasSize(1);
+        assertThat(extractRecoveryBarrierCheckpointId(chan0.recovered.get(0))).isEqualTo(cpId);
+        assertThat(chan1.recovered).isEmpty();
+        drainer.close();
+    }
+
+    @Test
+    void testSnapshotReturnsEmptyWhenDrainFinishedAndNotInRecovery() throws Exception {
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2));
+
+        RecordingChannel chan = new RecordingChannel(cInfo);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan);
+
+        drainer.drain();
+        chan.inRecovery = false;
+        int recoveredBefore = chan.recovered.size();
+
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(99L);
+        assertThat(snap.nextSegment()).isEmpty();
+        snap.close();
+
+        // No barrier added since channel left recovery
+        assertThat(chan.recovered).hasSize(recoveredBefore);
+        drainer.close();
+    }
+
+    @Test
+    void testSnapshotAfterDrainerClosedReturnsEmptyWithoutTouchingClosedRootReader()
+            throws Exception {
+        // Mirrors production order: drain() then close() (which closes the root reader) run
+        // before a late checkpoint fires snapshotAndInsertBarriers. The drain-finished flag must
+        // short-circuit so the closed root reader is never snapshotted.
+        InputChannelInfo cInfo = new InputChannelInfo(0, 0);
+        FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2));
+
+        RecordingChannel chan = new RecordingChannel(cInfo);
+        FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan);
+
+        drainer.drain();
+        drainer.close();
+
+        FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(99L);
+        assertThat(snap.nextSegment()).isEmpty();
+        snap.close();
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    /** Spills one record per payload for {@code ch} and returns the resulting state. */
+    private FetchedChannelState writeRecords(InputChannelInfo ch, byte[]... payloads)
+            throws IOException {
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            for (byte[] p : payloads) {
+                writer.writeRecord(ch, p, p.length);
+            }
+            return writer.getChannelState();
+        }
+    }
+
+    /**
+     * Creates a drainer over the given channels. Arguments alternate (info, channel); only the
+     * channel at each odd index is used, the info at each even index is ignored.
+     */
+    private FetchedChannelStateDrainer newDrainer(
+            FetchedChannelState state, Object... infoChannelPairs) {
+        List<RecoverableInputChannel> all = new ArrayList<>();
+        for (int i = 0; i < infoChannelPairs.length; i += 2) {
+            all.add((RecoverableInputChannel) infoChannelPairs[i + 1]);
+        }
+        return new FetchedChannelStateDrainer(state, all);
+    }
+
+    /** Deserializes the buffer as a {@link RecoveryCheckpointBarrier} and returns its id. */
+    private static long extractRecoveryBarrierCheckpointId(Buffer buffer) throws IOException {
+        AbstractEvent event =
+                EventSerializer.fromBuffer(
+                        buffer, RecoveryCheckpointBarrier.class.getClassLoader());
+        // Rewind so the same buffer can be read again by a later assertion.
+        buffer.setReaderIndex(0);
+        assertThat(event).isInstanceOf(RecoveryCheckpointBarrier.class);
+        return ((RecoveryCheckpointBarrier) event).getCheckpointId();
+    }
+
+    /** Builds a 4-byte payload: the low two bytes of {@code id} followed by 0xAB, 0xCD. */
+    private static byte[] payload(int id) {
+        return new byte[] {(byte) (id & 0xff), (byte) ((id >> 8) & 0xff), (byte) 0xAB, (byte) 0xCD};
+    }
+
+    /** Builds {@code n} bytes whose values count up modulo 256, so order mismatches are visible. */
+    private static byte[] sequentialBytes(int n) {
+        byte[] out = new byte[n];
+        for (int i = 0; i < n; i++) {
+            out[i] = (byte) i;
+        }
+        return out;
+    }
+
+    /** Concatenates the readable bytes of the given buffers in order. */
+    private static byte[] concat(List<Buffer> buffers) {
+        java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
+        for (Buffer b : buffers) {
+            java.nio.ByteBuffer nio = b.getNioBufferReadable();
+            byte[] chunk = new byte[nio.remaining()];
+            nio.get(chunk);
+            out.write(chunk, 0, chunk.length);
+        }
+        return out.toByteArray();
+    }
+
+    /** Fully consumes a segment body so the sequential reader may advance to the next segment. */
+    private static void drainBody(InputStream body) throws IOException {
+        byte[] buf = new byte[256];
+        while (body.read(buf) != -1) {
+            // discard
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // RecordingChannel stub
+    // -------------------------------------------------------------------------------------------
+
+    private static final int DEFAULT_RECOVERY_BUFFER_CAPACITY = 4096;
+
+    /**
+     * Channel stub that records every delivered buffer (data and inserted barriers alike) and,
+     * when given a shared sequence counter, the order of deliveries relative to finish calls.
+     */
+    private static final class RecordingChannel implements RecoverableInputChannel {
+        private final InputChannelInfo channelInfo;
+        final List<Buffer> recovered = new ArrayList<>();
+        int finishCalls = 0;
+        private final int[] sequence;
+        private final int bufferCapacity;
+        int maxDataSeq = Integer.MIN_VALUE;
+        int finishSeq = -1;
+        boolean inRecovery = true;
+
+        RecordingChannel(InputChannelInfo channelInfo) {
+            this(channelInfo, null, DEFAULT_RECOVERY_BUFFER_CAPACITY);
+        }
+
+        RecordingChannel(InputChannelInfo channelInfo, int[] sharedSequence) {
+            this(channelInfo, sharedSequence, DEFAULT_RECOVERY_BUFFER_CAPACITY);
+        }
+
+        RecordingChannel(InputChannelInfo channelInfo, int bufferCapacity) {
+            this(channelInfo, null, bufferCapacity);
+        }
+
+        RecordingChannel(InputChannelInfo channelInfo, int[] sharedSequence, int bufferCapacity) {
+            this.channelInfo = channelInfo;
+            this.sequence = sharedSequence;
+            this.bufferCapacity = bufferCapacity;
+        }
+
+        @Override
+        public InputChannelInfo getChannelInfo() {
+            return channelInfo;
+        }
+
+        @Override
+        public void onRecoveredStateBuffer(Buffer buffer) {
+            recovered.add(buffer);
+            if (sequence != null) {
+                maxDataSeq = Math.max(maxDataSeq, ++sequence[0]);
+            }
+        }
+
+        @Override
+        public void finishRecoveredBufferDelivery() {
+            finishCalls++;
+            if (sequence != null) {
+                finishSeq = ++sequence[0];
+            }
+        }
+
+        @Override
+        public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId)
+                throws IOException {
+            if (inRecovery) {
+                recovered.add(
+                        EventSerializer.toBuffer(
+                                new RecoveryCheckpointBarrier(checkpointId), false));
+            }
+        }
+
+        @Override
+        public Buffer requestRecoveryBufferBlocking() {
+            MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(bufferCapacity);
+            return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE);
+        }
+
+        @Override
+        public void onRecoveredStateConsumed() {}
+    }
+}
+diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java
new file mode 100644
index 00000000000000..a8bb14160f14e2
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java
@@ -0,0 +1,532 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link FetchedChannelStateReader}: sequential segment scanning, body boundedness, + * cross-file transparency, snapshot derivation, and fail-loud on truncated segments. + * + *

Segment boundaries are self-described in disk headers; no in-memory segment locator table is + * used. + */ +class FetchedChannelStateReaderTest { + + @TempDir Path tempDir; + + // ------------------------------------------------------------------------------------------- + // Segment iteration + // ------------------------------------------------------------------------------------------- + + @Test + void testIteratorEmptyWhenNoDataWritten() throws Exception { + // A writer that never spills produces no state; an empty state has no segments. + FetchedChannelState state = new FetchedChannelState(Collections.emptyList()); + try (FetchedChannelStateReader reader = state.reader()) { + assertThat(reader.nextSegment()).isEmpty(); + } + } + + @Test + void testMultipleIteratorIteratedInOrder() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, bytes(10), 1); + writer.writeRecord(c1, bytes(20), 1); + writer.writeRecord(c0, bytes(30), 1); + state = writer.getChannelState(); + } + + List channels = new ArrayList<>(); + try (FetchedChannelStateReader reader = state.reader()) { + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + channels.add(seg.channelInfo()); + readAll(seg.bodyStream()); + } + } + + // Segments are produced at channel switches: c0, c1, c0 + assertThat(channels).containsExactly(c0, c1, c0); + } + + // ------------------------------------------------------------------------------------------- + // Body boundedness: body() stops exactly at segment end + // ------------------------------------------------------------------------------------------- + + @Test + void testBodyReturnsMinus1AtSegmentEnd() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try 
(TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1, 2), 2); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader reader = state.reader()) { + SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new); + InputStream body = seg.bodyStream(); + // Read exactly length bytes + byte[] data = new byte[seg.length()]; + int totalRead = 0; + while (totalRead < data.length) { + int n = body.read(data, totalRead, data.length - totalRead); + assertThat(n).isGreaterThan(0); + totalRead += n; + } + // Next read must return EOF + assertThat(body.read()).isEqualTo(-1); + } + } + + // ------------------------------------------------------------------------------------------- + // Cross-file transparency + // ------------------------------------------------------------------------------------------- + + @Test + void testCrossFileTransparencyWhenRotationOccurs() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + // Use tiny rotation threshold so first segment triggers a file rotation. + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L /* 1 byte threshold */)) { + writer.writeRecord(c0, bytes(10, 11, 12), 3); + writer.writeRecord(c1, bytes(20, 21), 2); + state = writer.getChannelState(); + } + + // Two segments in different files. + assertThat(state.files()).hasSize(2); + + List channels = new ArrayList<>(); + try (FetchedChannelStateReader reader = state.reader()) { + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + channels.add(seg.channelInfo()); + // Body read must not throw even if the segment is in a different file. 
+ readAll(seg.bodyStream()); + } + } + + assertThat(channels).containsExactly(c0, c1); + } + + // ------------------------------------------------------------------------------------------- + // Snapshot: independent reader with correct start position + // ------------------------------------------------------------------------------------------- + + @Test + void testSnapshotCoversAllIteratorWhenNothingConsumed() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, bytes(1), 1); + writer.writeRecord(c1, bytes(2), 1); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader root = state.reader()) { + // Snapshot before consuming anything + try (FetchedChannelStateReader snap = root.snapshot().reader()) { + List channels = new ArrayList<>(); + Optional next; + while ((next = snap.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + channels.add(seg.channelInfo()); + readAll(seg.bodyStream()); + } + assertThat(channels).containsExactly(c0, c1); + } + } + } + + @Test + void testSnapshotAfterFullSegmentConsumedSkipsThatSegment() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, bytes(1), 1); + writer.writeRecord(c1, bytes(2), 1); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader root = state.reader()) { + // Consume and commit first segment + SpillSegment first = root.nextSegment().orElseThrow(AssertionError::new); + readAll(first.bodyStream()); + first.commit(); + + // Snapshot must start from second segment + try (FetchedChannelStateReader snap = root.snapshot().reader()) { + List channels = new ArrayList<>(); + Optional next; + while 
((next = snap.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + channels.add(seg.channelInfo()); + readAll(seg.bodyStream()); + } + assertThat(channels).containsExactly(c1); + } + } + } + + @Test + void testSnapshotFromMidSegmentStartsAtCommittedByteOffset() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + // Two records in the same channel -> one segment + writer.writeRecord(ch, bytes(10, 11), 2); + state = writer.getChannelState(); + } + // Verify: one file with one segment + assertThat(state.files()).hasSize(1); + + try (FetchedChannelStateReader root = state.reader()) { + SpillSegment seg = root.nextSegment().orElseThrow(AssertionError::new); + int fullLength = seg.length(); + InputStream body = seg.bodyStream(); + + // Read only 1 byte without committing, then snapshot — snapshot should start from 0 + // (no bytes committed yet). + body.read(); + + try (FetchedChannelStateReader snapBeforeCommit = root.snapshot().reader()) { + SpillSegment snapSeg = + snapBeforeCommit.nextSegment().orElseThrow(AssertionError::new); + assertThat(snapSeg.length()).isEqualTo(fullLength); + readAll(snapSeg.bodyStream()); + } + + // Read rest of body and commit + readAll(body); + seg.commit(); + + // After commit the snapshot must be empty + try (FetchedChannelStateReader snapAfterCommit = root.snapshot().reader()) { + assertThat(snapAfterCommit.nextSegment()).isEmpty(); + } + } + } + + @Test + void testSnapshotAfterPartialCommitReadsRemainingBodyTail() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + // One segment whose body is a single pass-through blob of known bytes. 
+ writer.writePassThrough(ch, bytes(1, 2, 3, 4, 5, 6, 7, 8), 0, 8); + state = writer.getChannelState(); + } + assertThat(state.files()).hasSize(1); + + try (FetchedChannelStateReader root = state.reader()) { + SpillSegment seg = root.nextSegment().orElseThrow(AssertionError::new); + int fullLength = seg.length(); + InputStream body = seg.bodyStream(); + + // Read and commit only a 3-byte prefix. + byte[] prefix = new byte[3]; + assertThat(body.read(prefix)).isEqualTo(3); + seg.commit(); + + // Snapshot must resume mid-segment and yield exactly the remaining tail bytes. + try (FetchedChannelStateReader snap = root.snapshot().reader()) { + SpillSegment snapSeg = snap.nextSegment().orElseThrow(AssertionError::new); + assertThat(snapSeg.channelInfo()).isEqualTo(ch); + assertThat(snapSeg.length()).isEqualTo(fullLength - 3); + byte[] tail = readAll(snapSeg.bodyStream()); + assertThat(tail).isEqualTo(bytes(4, 5, 6, 7, 8)); + assertThat(snap.nextSegment()).isEmpty(); + } + } + } + + @Test + void testSnapshotResumesPartialSegmentAcrossFileBoundary() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + // Tiny rotation threshold so the two segments land in separate files. + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L)) { + writer.writePassThrough(c0, bytes(1, 2, 3, 4), 0, 4); + writer.writePassThrough(c1, bytes(5, 6, 7), 0, 3); + state = writer.getChannelState(); + } + assertThat(state.files()).hasSize(2); + + try (FetchedChannelStateReader root = state.reader()) { + SpillSegment first = root.nextSegment().orElseThrow(AssertionError::new); + // Commit a 1-byte prefix of the file-0 segment. 
+ first.bodyStream().read(new byte[1]); + first.commit(); + + try (FetchedChannelStateReader snap = root.snapshot().reader()) { + SpillSegment resumed = snap.nextSegment().orElseThrow(AssertionError::new); + assertThat(resumed.channelInfo()).isEqualTo(c0); + assertThat(readAll(resumed.bodyStream())).isEqualTo(bytes(2, 3, 4)); + + // Crossing into file 1 must reset the skip to 0. + SpillSegment following = snap.nextSegment().orElseThrow(AssertionError::new); + assertThat(following.channelInfo()).isEqualTo(c1); + assertThat(readAll(following.bodyStream())).isEqualTo(bytes(5, 6, 7)); + + assertThat(snap.nextSegment()).isEmpty(); + } + } + } + + @Test + void testRootDrainViaRepeatedCommitsTerminatesAndFinalSnapshotEmpty() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writePassThrough(c0, bytes(1, 2), 0, 2); + writer.writePassThrough(c1, bytes(3, 4, 5), 0, 3); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader root = state.reader()) { + int count = 0; + Optional next; + while ((next = root.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + readAll(seg.bodyStream()); + seg.commit(); + count++; + } + assertThat(count).isEqualTo(2); + + // After draining everything, a snapshot must have nothing left. 
+ try (FetchedChannelStateReader snap = root.snapshot().reader()) { + assertThat(snap.nextSegment()).isEmpty(); + } + } + } + + // ------------------------------------------------------------------------------------------- + // Fail-loud on truncated segment + // ------------------------------------------------------------------------------------------- + + @Test + void testBodyThrowsEOFExceptionOnTruncatedFile() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1, 2, 3, 4, 5, 6, 7, 8), 8); + state = writer.getChannelState(); + } + + // Truncate the spill file to just the header (12 bytes) so the body is missing. + Path spill = state.files().get(0); + byte[] headerOnly = Files.readAllBytes(spill); + // Keep only the 12-byte header, discard body. + Files.write( + spill, + java.util.Arrays.copyOf(headerOnly, AbstractSpillingHandler.SEGMENT_HEADER_BYTES), + StandardOpenOption.TRUNCATE_EXISTING); + + try (FetchedChannelStateReader reader = state.reader()) { + SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new); + // bufferLength from header says > 0 bytes, but file has nothing after the header. 
+ assertThatThrownBy(() -> readAll(seg.bodyStream())) + .isInstanceOfAny(EOFException.class, IOException.class); + } + } + + // ------------------------------------------------------------------------------------------- + // Reference counting: acquire/release via reader lifecycle + // ------------------------------------------------------------------------------------------- + + @Test + void testReaderAcquiresAndReleasesRefCount() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1), 1); + state = writer.getChannelState(); + } + + Path spill = state.files().get(0); + + FetchedChannelStateReader reader = state.reader(); + // Drop the handoff grant so the reader's grant is the only one outstanding. + state.release(); + assertThat(java.nio.file.Files.exists(spill)).isTrue(); + + reader.close(); + + // After closing the only reader, the file is cleaned up. + assertThat(java.nio.file.Files.exists(spill)).isFalse(); + } + + @Test + void testSnapshotKeepsFilesAliveUntilSnapshotClosed() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1), 1); + state = writer.getChannelState(); + } + + Path spill = state.files().get(0); + + FetchedChannelStateReader root = state.reader(); + FetchedChannelStateReader snap = root.snapshot().reader(); + // Drop the handoff grant so only the two reader grants remain outstanding. + state.release(); + + root.close(); // One grant released; file must still exist because snap holds another. + assertThat(java.nio.file.Files.exists(spill)).isTrue(); + + snap.close(); // Last grant released; file must be deleted. 
+ assertThat(java.nio.file.Files.exists(spill)).isFalse(); + } + + // ------------------------------------------------------------------------------------------- + // New behaviour: first-call exemption, fail-loud on body not consumed, empty reader + // ------------------------------------------------------------------------------------------- + + @Test + void testFirstNextSegmentCallDoesNotRequirePreviousBodyConsumed() throws Exception { + // The "previous body must be fully consumed" rule must not fire on the very first call + // because there is no previous segment. + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(new InputChannelInfo(0, 0), bytes(1, 2), 2); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader reader = state.reader()) { + // Must not throw on the first call even though no body has been consumed before. + Optional seg = reader.nextSegment(); + assertThat(seg).isPresent(); + } + } + + @Test + void testNextSegmentThrowsWhenPreviousBodyNotFullyConsumed() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + InputChannelInfo ch2 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1, 2, 3, 4), 4); + writer.writeRecord(ch2, bytes(5, 6), 2); + state = writer.getChannelState(); + } + + try (FetchedChannelStateReader reader = state.reader()) { + SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new); + // Read only part of the body — do not exhaust it. + seg.bodyStream().read(); + + // Advancing to the next segment while the previous body is not fully consumed must fail + // loud. 
+ assertThatThrownBy(reader::nextSegment) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Previous segment body not fully consumed"); + } + } + + @Test + void testEmptyReaderNextSegmentReturnsEmptyAndCloseIsClean() throws Exception { + FetchedChannelState state = new FetchedChannelState(Collections.emptyList()); + try (FetchedChannelStateReader reader = state.reader()) { + // First call on an empty reader must return empty without throwing. + assertThat(reader.nextSegment()).isEmpty(); + // Closing must not throw. + } + } + + @Test + void testEmptyReaderHandsOutIndependentInstancesSoCloseDoesNotLeak() throws Exception { + // emptyReader() is obtained and closed once per checkpoint. close() is single-use (it flips + // the closed flag permanently), so each call must yield a fresh instance; otherwise the + // first consumer's close would make every later consumer's nextSegment() fail loud. + FetchedChannelStateReader first = FetchedChannelStateReader.emptyReader(); + assertThat(first.nextSegment()).isEmpty(); + first.close(); + + FetchedChannelStateReader second = FetchedChannelStateReader.emptyReader(); + assertThat(second).isNotSameAs(first); + // Must still work after the previously obtained empty reader was closed. + assertThat(second.nextSegment()).isEmpty(); + second.close(); + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private static byte[] readAll(InputStream in) throws IOException { + java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream(); + byte[] buf = new byte[256]; + int n; + while ((n = in.read(buf)) != -1) { + out.write(buf, 0, n); + } + return out.toByteArray(); + } + + private static byte[] bytes(int... 
values) { + byte[] arr = new byte[values.length]; + for (int i = 0; i < values.length; i++) { + arr[i] = (byte) values[i]; + } + return arr; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java new file mode 100644 index 00000000000000..e75d0009c2f65f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Verifies the reference counter on {@link FetchedChannelState}: acquire/release pairing, + * zero-triggered file deletion, idempotency, abort-path equivalence, and forced {@link + * FetchedChannelState#close()} cleanup. 
+ */ +class FetchedChannelStateRefCountTest { + + @TempDir Path tempDir; + + @Test + void testAcquireReleaseCountsMatch() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + + // The produced state already holds one handoff grant. + state.acquire(); + assertFilesExist(files, true); + + state.release(); + assertFilesExist(files, true); + + state.release(); + assertFilesExist(files, false); + } + + @Test + void testReachingZeroDeletesFiles() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + // The produced state already holds one handoff grant. + assertFilesExist(files, true); + + state.release(); + assertFilesExist(files, false); + } + + @Test + void testReleaseAfterZeroIsNoOp() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + // Release the single handoff grant the produced state already holds. + state.release(); + assertFilesExist(files, false); + + // Extra releases past zero must be a no-op. + state.release(); + state.release(); + assertFilesExist(files, false); + } + + @Test + void testAbortPathReleasesViaSameRoute() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + + // The produced state already holds one handoff grant. + state.acquire(); + state.acquire(); + + state.release(); + state.release(); + assertFilesExist(files, true); + + state.release(); + assertFilesExist(files, false); + } + + @Test + void testForceCloseStillCleansFiles() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + // The produced state already holds one handoff grant. + state.acquire(); + assertFilesExist(files, true); + + state.close(); + assertFilesExist(files, false); + + // Double close must be a no-op. + state.close(); + + // Late release after close must not re-delete or throw. 
+ state.release(); + assertFilesExist(files, false); + } + + @Test + void testReaderAcquiresAndReleasesOnClose() throws IOException { + FetchedChannelState state = newStateWithData(); + List files = state.files(); + + // Opening a reader acquires one grant; the produced state already holds the handoff grant. + FetchedChannelStateReader reader = state.reader(); + // Drop the handoff grant so the reader's grant is the only one outstanding. + state.release(); + assertFilesExist(files, true); + + // Closing releases that grant, which reaches zero and deletes files. + reader.close(); + assertFilesExist(files, false); + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private FetchedChannelState newStateWithData() throws IOException { + InputChannelInfo ch = new InputChannelInfo(0, 0); + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, new byte[] {1, 2, 3}, 3); + writer.writeRecord(new InputChannelInfo(0, 1), new byte[] {4, 5}, 2); + return writer.getChannelState(); + } + } + + private static void assertFilesExist(List files, boolean expected) { + for (Path file : files) { + assertThat(Files.exists(file)) + .as("file " + file + " exists=" + expected) + .isEqualTo(expected); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java new file mode 100644 index 00000000000000..c80656b65536f2 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link FetchedChannelState} lifecycle and file list management. 
*/ +class FetchedChannelStateTest { + + @TempDir Path tempDir; + + @Test + void testInitialStateIsEmpty() { + FetchedChannelState state = new FetchedChannelState(Collections.emptyList()); + assertThat(state.files()).isEmpty(); + assertThat(state.isClosed()).isFalse(); + } + + @Test + void testFileListPreservesOrder() throws IOException { + Path file0 = tempDir.resolve("spill-0.bin"); + Path file1 = tempDir.resolve("spill-1.bin"); + + try (FetchedChannelState state = new FetchedChannelState(Arrays.asList(file0, file1))) { + assertThat(state.files()).containsExactly(file0, file1); + } + } + + @Test + void testFilesListIsUnmodifiable() throws IOException { + try (FetchedChannelState state = + new FetchedChannelState(Collections.singletonList(tempDir.resolve("f0.bin")))) { + assertThatThrownBy(() -> state.files().add(tempDir.resolve("f1.bin"))) + .isInstanceOf(UnsupportedOperationException.class); + } + } + + @Test + void testAcquireReleaseDoesNotDeleteFilesBeforeLastRelease() throws IOException { + Path realFile = tempDir.resolve("spill-0.bin"); + realFile.toFile().createNewFile(); + FetchedChannelState state = new FetchedChannelState(Collections.singletonList(realFile)); + + state.acquire(); + state.acquire(); + + state.release(); + // File must still exist after first release. + assertThat(realFile.toFile()).exists(); + + state.release(); + // Last release should delete the file. 
+ assertThat(realFile.toFile()).doesNotExist(); + assertThat(state.isClosed()).isTrue(); + } + + @Test + void testCloseDeletesAllFiles() throws IOException { + Path file0 = tempDir.resolve("f0.bin"); + Path file1 = tempDir.resolve("f1.bin"); + file0.toFile().createNewFile(); + file1.toFile().createNewFile(); + + FetchedChannelState state = new FetchedChannelState(Arrays.asList(file0, file1)); + + state.close(); + + assertThat(file0.toFile()).doesNotExist(); + assertThat(file1.toFile()).doesNotExist(); + assertThat(state.isClosed()).isTrue(); + } + + @Test + void testCloseIsIdempotent() throws IOException { + FetchedChannelState state = new FetchedChannelState(Collections.emptyList()); + state.close(); + assertThat(state.isClosed()).isTrue(); + // Second close must not throw. + state.close(); + assertThat(state.isClosed()).isTrue(); + } + + @Test + void testCloseAfterReleaseIsIdempotent() throws IOException { + FetchedChannelState state = new FetchedChannelState(Collections.emptyList()); + state.acquire(); + state.release(); + assertThat(state.isClosed()).isTrue(); + // close() after last release must be a no-op (no double-delete attempt). 
+ state.close(); + assertThat(state.isClosed()).isTrue(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java index 85b4fd1d48ef1f..ae7b722f683b7b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java @@ -38,16 +38,14 @@ import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests buffer ownership semantics of {@link ChannelStateFilteringHandler.GateFilterHandler}. Each - * test verifies that buffers are properly recycled on both success and failure paths. + * test verifies that source buffers are properly recycled on both success and failure paths. 
*/ class GateFilterHandlerBufferOwnershipTest { @@ -60,13 +58,9 @@ void testSourceBufferRecycledOnSuccess() throws Exception { createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE)); - // sourceBuffer should be recycled by the deserializer after consumption assertThat(sourceBuffer.isRecycled()).isTrue(); - - // Clean up result buffers - result.forEach(Buffer::recycleBuffer); } @Test @@ -75,102 +69,60 @@ void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE)); - assertThat(result).isEmpty(); - // sourceBuffer should still be recycled even though no output was produced assertThat(sourceBuffer.isRecycled()).isTrue(); } @Test void testSourceBufferRecycledOnInvalidVirtualChannel() { - // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException + // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException. 
ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L); assertThatThrownBy( - () -> handler.filterAndRewrite(1, 1, sourceBuffer, this::createEmptyBuffer)) + () -> + handler.filterAndRewrite( + 1, 1, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE))) .isInstanceOf(IllegalStateException.class); - // sourceBuffer must be recycled even when lookup fails before setNextBuffer - assertThat(sourceBuffer.isRecycled()).isTrue(); - } - - @Test - void testResultBuffersAndCurrentBufferRecycledOnSerializationError() throws Exception { - // Use a small buffer so that records span multiple buffers. The supplier fails on the - // second request, after the first output buffer has been filled and added to resultBuffers. - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - - ChannelStateFilteringHandler.GateFilterHandler handler = - createHandler(RecordFilter.acceptAll()); - - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - - // The exception should propagate; no buffer leak (no IllegalReferenceCountException - // from double-recycle). - assertThatThrownBy(() -> handler.filterAndRewrite(0, 0, sourceBuffer, failingSupplier)) - .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); - - // sourceBuffer ownership was transferred to the deserializer via setNextBuffer(). - // The deserializer may still hold it if it hasn't fully consumed the buffer before the - // error. Calling clear() triggers the cleanup chain: - // GateFilterHandler#clear() -> VirtualChannel#clear() -> deserializer.clear() - handler.clear(); + // sourceBuffer must be recycled even when lookup fails before setNextBuffer. 
assertThat(sourceBuffer.isRecycled()).isTrue(); } /** - * Tests the production cleanup path: when filterAndRewrite throws mid-processing, the - * deserializer may still hold sourceBuffer. In production, ChannelStateFilteringHandler is used - * in a try-with-resources block (see {@code SequentialChannelStateReaderImpl#readInputData}), - * so its close() is guaranteed to be called, which triggers clear() on all GateFilterHandlers - * and their deserializers. This test simulates that exact pattern. + * When filterAndRewrite throws mid-processing, the deserializer may still hold sourceBuffer. In + * production, ChannelStateFilteringHandler is used in a try-with-resources block (see {@code + * SequentialChannelStateReaderImpl#readInputData}), so its close() is guaranteed to be called, + * which triggers clear() on all GateFilterHandlers and their deserializers. This test simulates + * that exact pattern. */ @Test void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - ChannelStateFilteringHandler.GateFilterHandler gateHandler = createHandler(RecordFilter.acceptAll()); - // Wrap in ChannelStateFilteringHandler, the production-level owner ChannelStateFilteringHandler filteringHandler = new ChannelStateFilteringHandler( new ChannelStateFilteringHandler.GateFilterHandler[] {gateHandler}); + // A serializer that throws while writing the second record's length prefix, triggering a + // mid-processing failure after the first record has already been emitted. 
+ DataOutputSerializer failingSerializer = new FailingAfterFirstRecordSerializer(); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - // Simulate the production try-with-resources pattern assertThatThrownBy( () -> { try (ChannelStateFilteringHandler ignored = filteringHandler) { filteringHandler.filterAndRewrite( - 0, 0, 0, sourceBuffer, failingSupplier); + 0, 0, 0, sourceBuffer, failingSerializer); } }) .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); + .hasMessage("Simulated write failure"); - // After close(), the entire cleanup chain has fired: - // ChannelStateFilteringHandler.close() -> GateFilterHandler.clear() - // -> VirtualChannel.clear() -> deserializer.clear() -> sourceBuffer.recycleBuffer() + // After close(), the entire cleanup chain has fired. assertThat(sourceBuffer.isRecycled()).isTrue(); } @@ -219,12 +171,26 @@ private Buffer createBufferWithRecords(Long... values) { } } - private Buffer createEmptyBuffer() { - return createEmptyBuffer(BUFFER_SIZE); - } + /** + * A {@link DataOutputSerializer} that throws an IOException while writing the second record's + * length prefix, simulating a failure mid-stream to verify that the source buffer is still + * recycled via the filtering handler's close() cleanup chain. Each surviving record begins with + * a {@code writeInt} placeholder for its length, so the second {@code writeInt} marks the start + * of the second record. 
+ */ + private static final class FailingAfterFirstRecordSerializer extends DataOutputSerializer { + private int writeIntCount = 0; + + FailingAfterFirstRecordSerializer() { + super(BUFFER_SIZE); + } - private Buffer createEmptyBuffer(int size) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); - return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + @Override + public void writeInt(int v) throws IOException { + if (++writeIntCount > 1) { + throw new IOException("Simulated write failure"); + } + super.writeInt(v); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java index f02ce35fd867d5..1646908727a88f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java @@ -57,10 +57,10 @@ void testAllRecordsPassFilter() throws Exception { createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - // deserializeBuffers consumes (recycles) each buffer via the deserializer - List values = deserializeBuffers(result); + List values = readRecordsFromSerializer(output); assertThat(values).containsExactly(1L, 2L, 3L); } @@ -70,9 +70,11 @@ void testAllRecordsFilteredOut() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new 
DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - assertThat(result).isEmpty(); + // No bytes should be written when all records are filtered out. + assertThat(output.length()).isZero(); } @Test @@ -81,42 +83,50 @@ void testPartialFiltering() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(keepEven); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - List values = deserializeBuffers(result); + List values = readRecordsFromSerializer(output); assertThat(values).containsExactly(2L, 4L); } @Test - void testSmallOutputBufferProducesMultipleBuffers() throws Exception { - // Use a very small output buffer size so records must span multiple buffers - int smallBufferSize = 8; + void testEmptyBuffer() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = - handler.filterAndRewrite( - 0, 0, sourceBuffer, () -> createEmptyBuffer(smallBufferSize)); + Buffer emptyBuffer = createEmptyBuffer(); + emptyBuffer.setSize(0); - // Each Long record needs 4 bytes length + ~9 bytes data > 8-byte buffer - assertThat(result.size()).isGreaterThan(1); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, emptyBuffer, output); - List values = deserializeBuffers(result); - assertThat(values).containsExactly(1L, 2L, 3L); + // No data written for an empty source buffer. 
+ assertThat(output.length()).isZero(); } @Test - void testEmptyBuffer() throws Exception { + void testSourceBufferRecycledOnSuccess() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); - Buffer emptyBuffer = createEmptyBuffer(); - emptyBuffer.setSize(0); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); + + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + @Test + void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { + RecordFilter rejectAll = record -> false; + ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); - List result = handler.filterAndRewrite(0, 0, emptyBuffer, this::createEmptyBuffer); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - assertThat(result).isEmpty(); + assertThat(sourceBuffer.isRecycled()).isTrue(); } // ------------------------------------------------------------------------------------------- @@ -141,23 +151,13 @@ private ChannelStateFilteringHandler.GateFilterHandler createHandler( private Buffer createBufferWithRecords(Long... values) throws IOException { StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); - return serializeRecordsToBuffer(serializer, values); - } - - /** Serializes records into a buffer using Flink's length-prefixed format. */ - private Buffer serializeRecordsToBuffer( - StreamElementSerializer serializer, Long... 
values) throws IOException { DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); for (Long value : values) { - // Serialize using the same length-prefixed format as Flink DataOutputSerializer recordOutput = new DataOutputSerializer(64); serializer.serialize(new StreamRecord<>(value), recordOutput); int recordLength = recordOutput.length(); - - // Write 4-byte big-endian length prefix output.writeInt(recordLength); - // Write record bytes output.write(recordOutput.getSharedBuffer(), 0, recordLength); } @@ -171,43 +171,49 @@ private Buffer serializeRecordsToBuffer( } private Buffer createEmptyBuffer() { - return createEmptyBuffer(BUFFER_SIZE); - } - - private Buffer createEmptyBuffer(int size) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE); return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); } - private List deserializeBuffers(List buffers) throws IOException { + /** + * Deserializes the records the handler appended into {@code output}. The body format is + * repeated (4B recordLen + N bytes of serialized StreamElement), which the deserializer reads + * directly. 
+ */ + private List readRecordsFromSerializer(DataOutputSerializer output) throws Exception { + List values = new ArrayList<>(); StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); + DeserializationDelegate delegate = + new NonReusingDeserializationDelegate<>(serializer); + + byte[] bodyBytes = output.getCopyOfBuffer(); + if (bodyBytes.length == 0) { + return values; + } + MemorySegment memSeg = MemorySegmentFactory.allocateUnpooledSegment(bodyBytes.length); + memSeg.put(0, bodyBytes); + NetworkBuffer buf = new NetworkBuffer(memSeg, FreeingBufferRecycler.INSTANCE); + buf.setSize(bodyBytes.length); + SpillingAdaptiveSpanningRecordDeserializer> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<>( new String[] {System.getProperty("java.io.tmpdir")}); - DeserializationDelegate delegate = - new NonReusingDeserializationDelegate<>(serializer); - - List values = new ArrayList<>(); - for (Buffer buffer : buffers) { - deserializer.setNextBuffer(buffer); - while (true) { - RecordDeserializer.DeserializationResult result = - deserializer.getNextRecord(delegate); - if (result.isFullRecord()) { - StreamElement element = delegate.getInstance(); - if (element.isRecord()) { - @SuppressWarnings("unchecked") - StreamRecord record = (StreamRecord) element; - values.add(record.getValue()); - } - } - if (result.isBufferConsumed()) { - break; + deserializer.setNextBuffer(buf); + + RecordDeserializer.DeserializationResult result; + do { + result = deserializer.getNextRecord(delegate); + if (result.isFullRecord()) { + StreamElement element = delegate.getInstance(); + if (element.isRecord()) { + @SuppressWarnings("unchecked") + StreamRecord record = (StreamRecord) element; + values.add(record.getValue()); } } - } + } while (!result.isBufferConsumed()); return values; } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index 9c4aab0bc7a5da..d82fb1866a688b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -32,7 +32,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.nio.file.Path; import java.util.HashSet; import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings; @@ -40,12 +42,15 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** Test of different implementation of {@link InputChannelRecoveredStateHandler}. */ +/** Test of different implementation of {@link AbstractInputChannelRecoveredStateHandler}. 
*/ class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest { + @TempDir private Path tmpDir; + private static final int preAllocatedSegments = 3; private NetworkBufferPool networkBufferPool; private SingleInputGate inputGate; - private InputChannelRecoveredStateHandler icsHandler; + // NoSpillingHandler: checkpointingDuringRecoveryEnabled=false, filteringHandler=null + private AbstractInputChannelRecoveredStateHandler icsHandler; private InputChannelInfo channelInfo; @BeforeEach @@ -61,14 +66,14 @@ void setUp() { .setSegmentProvider(networkBufferPool) .build(); - icsHandler = buildInputChannelStateHandler(inputGate); + icsHandler = buildNoSpillingHandler(inputGate); channelInfo = new InputChannelInfo(0, 0); } - private InputChannelRecoveredStateHandler buildInputChannelStateHandler( + private AbstractInputChannelRecoveredStateHandler buildNoSpillingHandler( SingleInputGate inputGate) { - return new InputChannelRecoveredStateHandler( + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {inputGate}, new InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -82,11 +87,13 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler( .InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) }), + false, null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } - private InputChannelRecoveredStateHandler buildMultiChannelHandler() { + private AbstractInputChannelRecoveredStateHandler buildMultiChannelHandler() { // Setup multi-channel scenario to trigger distribution constraint validation SingleInputGate multiChannelGate = new SingleInputGateBuilder() @@ -95,7 +102,7 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .setSegmentProvider(networkBufferPool) .build(); - return new InputChannelRecoveredStateHandler( + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {multiChannelGate}, new 
InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -110,18 +117,43 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .InflightDataGateOrPartitionRescalingDescriptor .MappingType.RESCALING) }), + false, null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } /** Builds a handler in filtering mode (non-null filtering handler, no-op stub). */ - private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler() { + private SpillingWithFilteringHandler buildFilteringInputChannelStateHandler() { // Empty GateFilterHandler array: filtering is "enabled" structurally, but no gate-level // filter logic runs. Suitable for exercising getBuffer() routing only. ChannelStateFilteringHandler stubFilteringHandler = new ChannelStateFilteringHandler( new ChannelStateFilteringHandler.GateFilterHandler[0]); - return new InputChannelRecoveredStateHandler( + return (SpillingWithFilteringHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + new InflightDataRescalingDescriptor( + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor[] { + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor( + new int[] {1}, + RescaleMappings.identity(1, 1), + new HashSet<>(), + InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor + .MappingType.IDENTITY) + }), + true, + stubFilteringHandler, + MemoryManager.DEFAULT_PAGE_SIZE, + new String[] {tmpDir.toAbsolutePath().toString()}); + } + + private AbstractInputChannelRecoveredStateHandler buildSpillingNoFilteringHandler( + String[] spillTmpDirectories) { + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {inputGate}, new InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -135,14 +167,16 @@ private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler 
.InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) }), - stubFilteringHandler, - MemoryManager.DEFAULT_PAGE_SIZE); + true, + null, + MemoryManager.DEFAULT_PAGE_SIZE, + spillTmpDirectories); } @Test void testBufferDistributedToMultipleInputChannelsThrowsException() throws Exception { // Test constraint that prevents buffer distribution to multiple channels - try (InputChannelRecoveredStateHandler handler = buildMultiChannelHandler()) { + try (AbstractInputChannelRecoveredStateHandler handler = buildMultiChannelHandler()) { assertThatThrownBy(() -> handler.getBuffer(channelInfo)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining( @@ -187,7 +221,7 @@ void testRecycleBufferAfterRecoverWasCalled() throws Exception { @Test void testPreFilterBufferIsolationFromNetworkBufferPool() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { int availableBefore = networkBufferPool.getNumberOfAvailableMemorySegments(); @@ -229,7 +263,7 @@ void testNonFilteringModeUsesNetworkBufferPool() throws Exception { @Test void testPreFilterSegmentReusedAcrossCalls() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { // First getBuffer() lazily allocates the segment. 
RecoveredChannelStateHandler.BufferWithContext first = @@ -259,7 +293,7 @@ void testPreFilterSegmentReusedAcrossCalls() throws Exception { @Test void testGetBufferThrowsWhenPriorBufferNotRecycled() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { RecoveredChannelStateHandler.BufferWithContext first = filteringHandler.getBuffer(channelInfo); @@ -283,8 +317,7 @@ void testGetBufferThrowsWhenPriorBufferNotRecycled() throws Exception { @Test void testPreFilterSegmentFreedOnClose() throws Exception { - InputChannelRecoveredStateHandler filteringHandler = - buildFilteringInputChannelStateHandler(); + SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler(); RecoveredChannelStateHandler.BufferWithContext bufferWithContext = filteringHandler.getBuffer(channelInfo); bufferWithContext.context.recycleBuffer(); @@ -298,4 +331,13 @@ void testPreFilterSegmentFreedOnClose() throws Exception { assertThat(segment.isFreed()).isTrue(); assertThat(filteringHandler.getPreFilterSegmentForTesting()).isNull(); } + + @Test + void testSpillingHandlerRequiresSpillDirectories() { + assertThatThrownBy(() -> buildSpillingNoFilteringHandler(null)) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> buildSpillingNoFilteringHandler(new String[0])) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("spillTmpDirectories must not be empty"); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java index c77208f3ff7493..bdff6d44718f1f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java @@ -74,6 +74,16 @@ public void addInputData( } } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + checkCheckpointId(checkpointId); + try { + reader.close(); + } catch (Exception e) { + rethrow(e); + } + } + @Override public void addOutputData( long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... data) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java new file mode 100644 index 00000000000000..493e88026dbeec --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.AfterEach; +import 
org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +class RecoveredChannelStateHandlerFilterRoutingTest { + + @TempDir Path tempDir; + + private NetworkBufferPool networkBufferPool; + private SingleInputGate inputGate; + private InputChannelInfo channelInfo; + + @BeforeEach + void setUp() { + // Plenty of segments so the filter-on path does not deadlock on the bounded pool. + networkBufferPool = new NetworkBufferPool(64, MemoryManager.DEFAULT_PAGE_SIZE); + inputGate = + new SingleInputGateBuilder() + .setChannelFactory(InputChannelBuilder::buildLocalRecoveredChannel) + .setSegmentProvider(networkBufferPool) + .build(); + channelInfo = new InputChannelInfo(0, 0); + } + + @AfterEach + void tearDown() throws Exception { + inputGate.close(); + networkBufferPool.destroy(); + } + + @Test + void testFilterOnRoutesOutputToChannelState() throws Exception { + ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler); + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + // Surviving records accumulate in the in-memory segment serializer; they are only + // sealed and flushed to a spill file on channel switch or close. With a single channel + // and no switch, close() is what seals the segment, so the assertion must follow it. 
+ handler.close(); + + assertThat(handler.peekSpillFilesForTesting()) + .as("filter-on path must spill the surviving records to a file") + .isNotEmpty(); + } + } + + @Test + void testFilterOnAccumulatorBuffersComeFromHeapNotPool() throws Exception { + // The accumulator's prefilter + postfilter buffers are unpooled heap segments owned by + // the handler — invoking filter recovery must NOT consume any network buffer pool + // segments for the accumulator path. Channel-side exclusive buffer reservation, if any, + // happens at handler construction and is already reflected in the pre-recover snapshot. + ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + int availableBeforeRecover = networkBufferPool.getNumberOfAvailableMemorySegments(); + + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + int availableAfterRecover = networkBufferPool.getNumberOfAvailableMemorySegments(); + assertThat(availableAfterRecover) + .as("filter accumulator buffers must not be sourced from the network pool") + .isEqualTo(availableBeforeRecover); + + handler.close(); + inputGate.close(); + assertThat(networkBufferPool.getNumberOfAvailableMemorySegments()) + .as("pool count after close must match pre-recover (filter took nothing)") + .isEqualTo(availableBeforeRecover); + } + } + + @Test + void testFilterOnDoesNotInvokeChannelOnRecoveredStateBuffer() throws Exception { + ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + try (ChannelStateFilteringHandler ignored = filteringHandler; + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler)) { + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + int queuedDuringRecovery = countQueuedRecoveredBuffers(); + assertThat(queuedDuringRecovery) + .as("filter-on must not enqueue buffers into the channel during recovery") + 
.isEqualTo(0); + } + } + + @Test + void testFilterOffDoesNotCreateChannelState() throws Exception { + try (NoSpillingHandler handler = newFilterOffHandler()) { + invokeRecoverWithRawBytes(handler, new byte[] {1, 2, 3, 4}); + + // NoSpillingHandler has no channelStateWriter, so peek always returns null. + assertThat(handler.getProducedChannelState()) + .as("filter-off close() must not publish a FetchedChannelState either") + .isNull(); + } + } + + @Test + void testFilterOffMaintainsMasterBehavior() throws Exception { + try (NoSpillingHandler handler = newFilterOffHandler()) { + invokeRecoverWithRawBytes(handler, new byte[] {1, 2, 3, 4}); + + // Filter-off path enqueues the SubtaskConnectionDescriptor event plus the data + // buffer directly into the channel's recoveredBuffers. + int queued = countQueuedRecoveredBuffers(); + assertThat(queued) + .as("filter-off must enqueue the descriptor + data buffer into the channel") + .isGreaterThanOrEqualTo(2); + } + } + + private SpillingWithFilteringHandler newFilterOnHandler( + ChannelStateFilteringHandler filteringHandler) { + return (SpillingWithFilteringHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + identityRescalingForOneGate(), + true, + filteringHandler, + MemoryManager.DEFAULT_PAGE_SIZE, + new String[] {tempDir.toString()}); + } + + private NoSpillingHandler newFilterOffHandler() { + return (NoSpillingHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + identityRescalingForOneGate(), + false, + null, + MemoryManager.DEFAULT_PAGE_SIZE, + null); + } + + private static InflightDataRescalingDescriptor identityRescalingForOneGate() { + return new InflightDataRescalingDescriptor( + new InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor + [] { + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor( + new int[] {1}, + RescaleMappings.identity(1, 1), + new HashSet<>(), + 
InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor.MappingType + .IDENTITY) + }); + } + + private ChannelStateFilteringHandler newPassThroughFilteringHandler() { + StreamElementSerializer serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + RecordDeserializer> deserializer = + new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[] {System.getProperty("java.io.tmpdir")}); + VirtualChannel vc = new VirtualChannel<>(deserializer, RecordFilter.acceptAll()); + Map> channels = new HashMap<>(); + // The handler will invoke filterAndRewrite with oldSubtaskIndex=1 — keep the key aligned. + channels.put(new SubtaskConnectionDescriptor(1, channelInfo.getInputChannelIdx()), vc); + + ChannelStateFilteringHandler.GateFilterHandler gateHandler = + new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); + return new ChannelStateFilteringHandler( + new ChannelStateFilteringHandler.GateFilterHandler[] {gateHandler}); + } + + private void invokeRecoverWithRecords( + AbstractInputChannelRecoveredStateHandler handler, Long... values) throws Exception { + Buffer source = createRecordBuffer(values); + invokeRecoverWithBuffer(handler, source); + } + + private void invokeRecoverWithRawBytes( + AbstractInputChannelRecoveredStateHandler handler, byte[] data) throws Exception { + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(data.length); + seg.put(0, data); + NetworkBuffer source = new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + source.setSize(data.length); + invokeRecoverWithBuffer(handler, source); + } + + /** + * Mirrors the chunkReader's getBuffer + recover sequence: handler-issued buffer is filled with + * the source data, then handed back to recover. 
+ */ + private void invokeRecoverWithBuffer( + AbstractInputChannelRecoveredStateHandler handler, Buffer source) throws Exception { + RecoveredChannelStateHandler.BufferWithContext bwc = handler.getBuffer(channelInfo); + try { + Buffer dest = bwc.context; + int len = source.readableBytes(); + source.getMemorySegment() + .copyTo( + source.getMemorySegmentOffset(), + dest.getMemorySegment(), + dest.getMemorySegmentOffset(), + len); + dest.setSize(len); + } finally { + source.recycleBuffer(); + } + // oldSubtaskIndex=1 matches the pass-through filter's virtual channel key. The filter-off + // path ignores this argument's mapping (it only flows into a SubtaskConnectionDescriptor). + handler.recover(channelInfo, 1, bwc); + } + + private static Buffer createRecordBuffer(Long... values) throws IOException { + StreamElementSerializer serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + DataOutputSerializer output = new DataOutputSerializer(256); + for (Long v : values) { + DataOutputSerializer rec = new DataOutputSerializer(64); + serializer.serialize(new StreamRecord<>(v), rec); + int recordLength = rec.length(); + output.writeInt(recordLength); + output.write(rec.getSharedBuffer(), 0, recordLength); + } + byte[] data = output.getCopyOfBuffer(); + MemorySegment segment = + MemorySegmentFactory.allocateUnpooledSegment(MemoryManager.DEFAULT_PAGE_SIZE); + segment.put(0, data, 0, data.length); + NetworkBuffer buf = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + buf.setSize(data.length); + return buf; + } + + /** + * Counts the number of buffers currently queued in the only recovered input channel by draining + * via the public {@code getNextBuffer()} entry point. After this helper returns the channel + * queue is empty by definition; tests should not call it twice expecting the same answer. 
+ */ + private int countQueuedRecoveredBuffers() throws IOException { + RecoveredInputChannel ch = (RecoveredInputChannel) inputGate.getChannel(0); + int count = 0; + while (true) { + Optional next = ch.getNextBuffer(); + if (!next.isPresent()) { + break; + } + count++; + next.get().buffer().recycleBuffer(); + if (count > 1000) { + throw new IllegalStateException("Unexpected unbounded queue contents"); + } + } + return count; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java index 01f7c43920ae94..db302d23f73376 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java @@ -22,7 +22,7 @@ /** * Base class which contains all tests which should be implemented for every implementation of - * {@link InputChannelRecoveredStateHandler}. + * {@link AbstractInputChannelRecoveredStateHandler}. */ abstract class RecoveredChannelStateHandlerTest { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java new file mode 100644 index 00000000000000..0dd2c18286a7ac --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RecoveryCheckpointBarrier}. 
*/ +class RecoveryCheckpointBarrierTest { + + @Test + void testReflectiveWriteIsUnsupported() { + RecoveryCheckpointBarrier barrier = new RecoveryCheckpointBarrier(1L); + + assertThatThrownBy(() -> barrier.write(new DataOutputSerializer(Long.BYTES))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("dedicated type-tag path"); + } + + @Test + void testEventSerializerHandlesRecoveryCheckpointBarrier() throws Exception { + long checkpointId = 123L; + + Buffer buffer = + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(checkpointId), false); + Object deserialized = + EventSerializer.fromBuffer( + buffer, RecoveryCheckpointBarrier.class.getClassLoader()); + + assertThat(deserialized).isEqualTo(new RecoveryCheckpointBarrier(checkpointId)); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java new file mode 100644 index 00000000000000..9078b682bb2b2f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Regression coverage for the heap-blowup scenario produced by rescale + filter + a large recovered + * record: a workload whose total recovered bytes greatly exceed any single buffer must spill to + * disk rather than pin the bytes on the task heap. + * + *

<p>A full {@code MiniCluster} reproduction would need a tuned heap and a stateful rescale job + * graph, which is too heavy for a unit-style test. This ITCase asserts the memory-bound invariant + * directly: bytes land on disk in bounded segments and the writer keeps no per-record heap + * allocation. + * + *
<p>
Segment boundaries are self-described in the 12-byte disk headers + * ([gateIdx][channelIdx][bufferLength]); no in-memory segment locator table is maintained. + */ +class RescaleFilterLargeRecordOOMRegressionITCase { + + @TempDir Path tempDir; + + /** + * Verifies that large records for a single channel land on disk rather than remaining on the + * heap. The disk-landing invariant is confirmed by: + * + *

<ol> + *
<li>Reading back the data via the reader and confirming byte count and content. + *
<li>The spill file's physical size exceeds the rotation cap (segment not split). + *
</ol> + * + *
<p>
File count is intentionally not asserted here: a single channel produces one segment, and + * a segment is never split across files. With all writes going to channel (0,0), there is + * exactly one segment, which resides entirely in one file regardless of its size. That is the + * correct behavior per the "segment does not cross files" design invariant. + */ + @Test + void testLargeRecordsLandOnDiskNotHeap() throws Exception { + // Total bytes greatly exceed a single segment size to exercise the spill path. + final long segmentSize = 4L * 1024 * 1024; + final int largeRecordSize = 256 * 1024; + final int recordCount = 64; + // Each writeRecord call writes a 4-byte length prefix plus the record bytes. + final long expectedBodyBytes = (long) recordCount * (Integer.BYTES + largeRecordSize); + assertThat(expectedBodyBytes) + .as("workload exceeds the segment-size cap") + .isGreaterThan(segmentSize); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, segmentSize)) { + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + byte[] reusableRecord = new byte[largeRecordSize]; + for (int i = 0; i < reusableRecord.length; i++) { + reusableRecord[i] = (byte) (i & 0xff); + } + for (int i = 0; i < recordCount; i++) { + // The reusable byte array is mutated and reused across writes — the writer must + // flush through to disk without retaining a reference, otherwise per-record heap + // pressure would accumulate. + reusableRecord[0] = (byte) i; + writer.writeRecord(channelInfo, reusableRecord, reusableRecord.length); + } + state = writer.getChannelState(); + } + + // Verify data is physically on disk and can be read back correctly. + // The reader holds a lifecycle grant; closing it triggers file deletion. + long totalReadBytes = 0L; + try (FetchedChannelStateReader reader = state.reader()) { + // One channel => one segment => one file (segment not split across files). 
+ assertThat(state.files()) + .as("single channel produces exactly one spill file") + .hasSize(1); + long physicalFileSize = Files.size(state.files().get(0)); + // File size = 12B header + expectedBodyBytes + assertThat(physicalFileSize) + .as("physical file size exceeds the rotation cap (expected: segment not split)") + .isGreaterThan(segmentSize); + + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + try (InputStream body = seg.bodyStream()) { + byte[] buf = new byte[4096]; + int read; + while ((read = body.read(buf)) != -1) { + totalReadBytes += read; + } + } + seg.commit(); + } + } + assertThat(totalReadBytes) + .as("all written body bytes must be readable from disk") + .isEqualTo(expectedBodyBytes); + } + + /** + * Verifies that multiple channels interleaved across writes do trigger file rotation, producing + * more than one file when the total written bytes exceed the rotation cap. This is the "segment + * switch triggers rotation" path complementing the single-channel test above. + */ + @Test + void testMultiChannelWritesTriggerFileRotation() throws IOException { + // Small cap so that alternating two channels quickly crosses the threshold. + final long segmentSize = 512L * 1024; + final int recordSize = 64 * 1024; + final int roundCount = 10; // each round writes one record per channel + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, segmentSize)) { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + byte[] record = new byte[recordSize]; + for (int round = 0; round < roundCount; round++) { + writer.writeRecord(c0, record, record.length); + writer.writeRecord(c1, record, record.length); + } + state = writer.getChannelState(); + } + + // Alternating channels produce multiple segments (one file per seal above the cap). 
+ assertThat(state.files()) + .as("total size exceeding the cap must trigger file rotation into multiple files") + .hasSizeGreaterThanOrEqualTo(2); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java index d80442b8a06f8d..e11379624bf81b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java @@ -51,10 +51,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -90,6 +92,8 @@ public static List parameters() { new Object[] {"ReadPermutedStateWithIncreasedBuffer", 10, 10, 10, 10, 20}); } + @TempDir static Path tmpDir; + @Parameter public String desc; @Parameter(value = 1) @@ -144,7 +148,16 @@ void testReadPermutedState() throws Exception { withInputGates( gates -> { - reader.readInputData(gates, RecordFilterContext.disabled()); + reader.readInputData( + gates, + RecordFilterContext.disabled( + new String[] {tmpDir.toAbsolutePath().toString()})); + // Mirror the production legacy path (StreamTask#readInputChannelState): the + // EndOfInputChannelStateEvent sentinel that completes stateConsumedFuture is + // enqueued by finishReadRecoveredState(), not by readInputData() itself. 
+ for (InputGate gate : gates) { + gate.finishReadRecoveredState(); + } assertBuffersEquals(inputChannelsData, collectBuffers(gates)); assertConsumed(gates); }); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java new file mode 100644 index 00000000000000..4b3f1b9aa64968 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Path; + +/** + * Test-only helper that produces spill files in the {@link AbstractSpillingHandler} on-disk format, + * exposing the same {@code writeRecord} / {@code writePassThrough} surface the readers and drainers + * are tested against. 
It drives a minimal concrete {@link AbstractSpillingHandler} so tests need + * not stand up real input gates or run the recovery loop. + */ +final class TestSpillWriter implements Closeable { + + private final FormatHandler handler; + + TestSpillWriter(Path baseDir) { + this(baseDir, AbstractSpillingHandler.DEFAULT_SPILL_FILE_SIZE_BYTES); + } + + TestSpillWriter(Path baseDir, long maxFileSizeBytes) { + this.handler = new FormatHandler(new String[] {baseDir.toString()}, maxFileSizeBytes); + } + + /** Appends one length-prefixed record, mirroring the filtering path. */ + void writeRecord(InputChannelInfo channelInfo, byte[] record, int recordLength) + throws IOException { + DataOutputSerializer segment = handler.segmentSerializerFor(channelInfo); + segment.writeInt(recordLength); + segment.write(record, 0, recordLength); + } + + /** Appends verbatim bytes, mirroring the pass-through path. */ + void writePassThrough(InputChannelInfo channelInfo, byte[] data, int offset, int length) + throws IOException { + handler.segmentSerializerFor(channelInfo).write(data, offset, length); + } + + /** Opens (or switches to) a segment without writing any body, to exercise empty segments. */ + void openSegment(InputChannelInfo channelInfo) throws IOException { + handler.segmentSerializerFor(channelInfo); + } + + /** + * Seals the spilled segments and returns the produced state, already holding one lifecycle + * grant for the caller. Returns {@code null} if nothing was ever written. 
+ */ + FetchedChannelState getChannelState() throws IOException { + handler.close(); + return handler.getProducedChannelState(); + } + + @Override + public void close() throws IOException { + handler.close(); + } + + private static final class FormatHandler extends AbstractSpillingHandler { + + FormatHandler(String[] spillTmpDirectories, long maxFileSizeBytes) { + super( + new InputGate[0], + InflightDataRescalingDescriptor.NO_RESCALE, + spillTmpDirectories, + maxFileSizeBytes); + } + + @Override + public void recover( + InputChannelInfo info, + int oldSubtaskIndex, + BufferWithContext ctx) { + throw new UnsupportedOperationException("not used in format tests"); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java new file mode 100644 index 00000000000000..b073f16946d140 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration coverage for the recovery-checkpoint dispatcher. End-to-end rescaling coverage + * against a real {@code MiniCluster} lives in {@code UnalignedCheckpointRescaleITCase}; this class + * pins the disjoint-and-complete invariant that the recovery-time checkpoint slice must satisfy, + * using a unit-style fixture that mirrors the recovery-checkpoint slice produced by the production + * drain/snapshot path. + */ +class UnalignedCheckpointDuringRecoveryITCase { + + @Test + void testStep1SnapshotPlusStep2PreBarrierBytesEqualOriginal() { + // Fixture: the recovery filter wrote a sequence of buffers to disk; the drain has + // delivered entries 0..2 into channel queues, while entries 3..6 remain on disk. The + // dispatcher snapshots the on-disk slice and inserts barriers into channel queues, so + // the in-channel pre-barrier portion plus the on-disk slice together must cover every + // original byte exactly once. + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + List originalSeq = + Arrays.asList( + new RecoveredEntry(c0, new byte[] {1, 2}), + new RecoveredEntry(c1, new byte[] {3}), + new RecoveredEntry(c0, new byte[] {4, 5, 6}), + new RecoveredEntry(c1, new byte[] {7, 8}), + new RecoveredEntry(c0, new byte[] {9}), + new RecoveredEntry(c0, new byte[] {10, 11}), + new RecoveredEntry(c1, new byte[] {12})); + + // First three entries are still in channel queues (the dispatcher's per-input walk + // covers them); the remaining four sit on disk (the writer's async demux covers them). 
+ List step2Sources = originalSeq.subList(0, 3); + List step3Sources = originalSeq.subList(3, originalSeq.size()); + + Map persistedByChannel = new HashMap<>(); + for (RecoveredEntry entry : step2Sources) { + persistedByChannel.merge(entry.channelInfo, entry.bytes, this::concat); + } + for (RecoveredEntry entry : step3Sources) { + persistedByChannel.merge(entry.channelInfo, entry.bytes, this::concat); + } + + // Persisted per-channel bytes must equal the concatenation of the original sequence, + // regardless of which source produced each byte — no duplication, no gaps. + Map expected = new HashMap<>(); + for (RecoveredEntry entry : originalSeq) { + expected.merge(entry.channelInfo, entry.bytes, this::concat); + } + assertThat(persistedByChannel.keySet()).isEqualTo(expected.keySet()); + for (InputChannelInfo info : expected.keySet()) { + assertThat(persistedByChannel.get(info)).isEqualTo(expected.get(info)); + } + } + + @Test + void testEmptyDiskSnapshotReaderCloseIsClean() throws Exception { + // An empty reader (no spill files) must be closeable without error. This guards the + // fixture against silently swallowing close() failures on the zero-segment path. 
+ FetchedChannelStateReader emptyReader = FetchedChannelStateReader.emptyReader(); + assertThat(emptyReader.nextSegment()).isEmpty(); + emptyReader.close(); + } + + private byte[] concat(byte[] a, byte[] b) { + byte[] out = new byte[a.length + b.length]; + System.arraycopy(a, 0, out, 0, a.length); + System.arraycopy(b, 0, out, a.length, b.length); + return out; + } + + private static final class RecoveredEntry { + final InputChannelInfo channelInfo; + final byte[] bytes; + + RecoveredEntry(InputChannelInfo channelInfo, byte[] bytes) { + this.channelInfo = channelInfo; + this.bytes = bytes; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java index 2adbded7d9e90e..d165f6a50cb4d8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java @@ -106,7 +106,8 @@ void testCancelPartitionRequest() throws Exception { pid, new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -170,7 +171,8 @@ void testDuplicateCancel() throws Exception { pid, new ResultSubpartitionIndexSet(0), inputChannelId, - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java index d4b162304b8e6b..d58018036961a3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java @@ -68,7 +68,6 @@ import java.io.IOException; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import java.util.stream.Stream; import static org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel; @@ -956,7 +955,7 @@ private static class TestRemoteInputChannelForError extends RemoteInputChannel { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); this.expectedMessage = expectedMessage; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java index cd4a3103240ee7..ba686adce04959 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java @@ -88,7 +88,7 @@ private static CreditBasedSequenceNumberingViewReader createNetworkSequenceViewR channel.close(); CreditBasedSequenceNumberingViewReader reader = new CreditBasedSequenceNumberingViewReader( - new InputChannelID(), initialCredit, queue); + new InputChannelID(), initialCredit, false, queue); reader.notifySubpartitionsCreated( TestingResultPartition.newBuilder() .setCreateSubpartitionViewFunction( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java index 22b5420b616d2b..f2e0e8146c8da6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java @@ -67,7 +67,8 @@ void testPartitionRequest() { new ResultPartitionID(), new ResultSubpartitionIndexSet(queueIndex), new InputChannelID(), - random.nextInt()); + random.nextInt(), + random.nextBoolean()); NettyMessage.PartitionRequest actual = encodeAndDecode(expected, channel); @@ -75,6 +76,7 @@ void testPartitionRequest() { assertThat(actual.queueIndexSet).isEqualTo(expected.queueIndexSet); assertThat(actual.receiverId).isEqualTo(expected.receiverId); assertThat(actual.credit).isEqualTo(expected.credit); + assertThat(actual.needsRecovery).isEqualTo(expected.needsRecovery); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java index 7046a8518093be..a4e15e4cf90d6c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java @@ -90,7 +90,8 @@ public void testNotifyReaderPartitionTimeout() throws Exception { ResultPartitionManager resultPartitionManager = new ResultPartitionManager(); ResultPartitionID resultPartitionId = new ResultPartitionID(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(new InputChannelID(0, 0), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(0, 0), 10, false, queue); reader.requestSubpartitionViewOrRegisterListener( resultPartitionManager, resultPartitionId, new ResultSubpartitionIndexSet(0)); @@ -132,9 +133,11 @@ void testNotifyReaderNonEmptyOnEmptyReaders() throws Exception { EmbeddedChannel channel = new EmbeddedChannel(queue); CreditBasedSequenceNumberingViewReader reader1 = - new CreditBasedSequenceNumberingViewReader(new 
InputChannelID(0, 0), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(0, 0), 10, false, queue); CreditBasedSequenceNumberingViewReader reader2 = - new CreditBasedSequenceNumberingViewReader(new InputChannelID(1, 1), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(1, 1), 10, false, queue); ResultSubpartitionView view1 = new EmptyAlwaysAvailableResultSubpartitionView(); reader1.notifySubpartitionsCreated( @@ -192,7 +195,8 @@ private void testBufferWriting(ResultSubpartitionView view) throws IOException { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, Integer.MAX_VALUE, queue); + new CreditBasedSequenceNumberingViewReader( + receiverId, Integer.MAX_VALUE, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -287,7 +291,7 @@ void testEnqueueReaderByNotifyingEventBuffer() throws Exception { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 0, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 0, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -340,7 +344,7 @@ void testEnqueueReaderByNotifyingBufferAndCredit() throws Exception { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 
2, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.addCredit(-2); @@ -421,7 +425,7 @@ void testEnqueueReaderByResumingConsumption() throws Exception { InputChannelID receiverId = new InputChannelID(); PartitionRequestQueue queue = new PartitionRequestQueue(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -461,7 +465,7 @@ void testAnnounceBacklog() throws Exception { PartitionRequestQueue queue = new PartitionRequestQueue(); InputChannelID receiverId = new InputChannelID(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 0, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 0, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -498,7 +502,7 @@ private void testCancelPartitionRequest(boolean isAvailableView) throws Exceptio final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -546,7 +550,7 @@ void testNotifyNewBufferSize() throws Exception { InputChannelID receiverId = new InputChannelID(); PartitionRequestQueue queue = new PartitionRequestQueue(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new 
CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java index e3cfb55e3400ff..86d9588094f9a3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java @@ -44,7 +44,6 @@ import org.junit.jupiter.api.Test; -import java.util.ArrayDeque; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -96,7 +95,8 @@ void testRegisterResultPartitionBeforeRequest() throws Exception { resultPartition.getPartitionId(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -144,7 +144,8 @@ void testRegisterResultPartitionAfterRequest() throws Exception { resultPartition.getPartitionId(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Register result partition after partition request @@ -213,7 +214,8 @@ public void releasePartitionRequestListener( pid, new ResultSubpartitionIndexSet(0), remoteInputChannel.getInputChannelId(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -250,7 +252,7 @@ private static class TestRemoteInputChannelForPartitionNotFound extends RemoteIn new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); this.latch = latch; } diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java index 7f19b199582e7c..1fe0c00e79a188 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java @@ -81,7 +81,7 @@ void testAcknowledgeAllRecordsProcessed() throws IOException { // Creates and registers the view to netty. NetworkSequenceViewReader viewReader = new CreditBasedSequenceNumberingViewReader( - inputChannelID, 2, partitionRequestQueue); + inputChannelID, 2, false, partitionRequestQueue); viewReader.notifySubpartitionsCreated(resultPartition, new ResultSubpartitionIndexSet(0)); partitionRequestQueue.notifyReaderCreated(viewReader); @@ -149,7 +149,7 @@ private static class TestViewReader extends CreditBasedSequenceNumberingViewRead TestViewReader( InputChannelID receiverId, int initialCredit, PartitionRequestQueue requestQueue) { - super(receiverId, initialCredit, requestQueue); + super(receiverId, initialCredit, false, requestQueue); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java index a4f753fe1e7d26..bfc4e01a153cc6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java @@ -103,7 +103,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { new ResultPartitionID(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)); + Integer.MAX_VALUE, + false)); // Wait 
for the notification assertThat(sync.await(TestingUtils.TESTING_DURATION.toMillis(), TimeUnit.MILLISECONDS)) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java index 08f65d9fe72658..b171dc41c74ec2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; @@ -56,6 +55,7 @@ public class InputChannelBuilder { private int maxBackoff = 0; private int partitionRequestListenerTimeout = 0; private int networkBuffersPerChannel = 2; + private boolean needsRecovery = false; private InputChannelMetrics metrics = InputChannelTestUtils.newUnregisteredInputChannelMetrics(); @@ -115,6 +115,11 @@ public InputChannelBuilder setNetworkBuffersPerChannel(int networkBuffersPerChan return this; } + public InputChannelBuilder setNeedsRecovery(boolean needsRecovery) { + this.needsRecovery = needsRecovery; + return this; + } + public InputChannelBuilder setMetrics(InputChannelMetrics metrics) { this.metrics = metrics; return this; @@ -166,7 +171,8 @@ public LocalInputChannel buildLocalChannel(SingleInputGate inputGate) { metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), stateWriter, - new ArrayDeque<>()); + networkBuffersPerChannel, + needsRecovery); } public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { @@ -184,7 +190,7 @@ public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { 
metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), stateWriter, - new ArrayDeque<>()); + needsRecovery); } public LocalRecoveredInputChannel buildLocalRecoveredChannel(SingleInputGate inputGate) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index 86bda9866d204c..a84c5940c9df75 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -19,10 +19,13 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.disk.NoOpFileChannelManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; @@ -61,7 +64,6 @@ import org.mockito.stubbing.Answer; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -669,7 +671,7 @@ void testReceivingBuffersInUseBeforeSubpartitionViewInitialization() throws Exce @Test void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { - // given: Local input channel with recovered buffers in toBeConsumedBuffers + 
// given: Local input channel with recovered buffers in the recovery queue ResultSubpartitionView subpartitionView = InputChannelTestUtils.createResultSubpartitionView( createFilledFinishedBufferConsumer(4096), @@ -678,12 +680,6 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new TestingResultPartitionManager(subpartitionView); final SingleInputGate inputGate = createSingleInputGate(1); - // Create 3 recovered buffers - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - final LocalInputChannel localChannel = new LocalInputChannel( inputGate, @@ -697,10 +693,16 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(localChannel); + // Create 3 recovered buffers + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + // then: Before requesting subpartitions, buffers in use should include recovered buffers assertThat(localChannel.getBuffersInUseCount()).isEqualTo(3); assertThat(localChannel.unsynchronizedGetNumberOfQueuedBuffers()).isEqualTo(3); @@ -718,10 +720,6 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { // given: LocalInputChannel with recovered buffers migrated from RecoveredInputChannel SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - LocalInputChannel channel = new LocalInputChannel( inputGate, @@ -735,10 +733,16 @@ void 
testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + // then: Can read recovered buffers even before requestSubpartitions() Optional first = channel.getNextBuffer(); assertThat(first).isPresent(); @@ -755,11 +759,6 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { // given: Local input channel with recovered buffers SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - recoveredBuffers.add(TestBufferFactory.createBuffer(30)); - RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); LocalInputChannel channel = @@ -775,19 +774,28 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); - // when: Checkpoint is started + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(30)); + // Mirror snapshotAndInsertBarriers: push the sentinel before + // checkpointStarted scans recoveredQueue. 
+ channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + CheckpointOptions options = CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); stateWriter.start(1L, options); CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, options); + // when: Checkpoint is started channel.checkpointStarted(barrier); - // then: All 3 recovered buffers should be persisted as inflight data List persistedBuffers = stateWriter.getAddedInput().get(channel.getChannelInfo()); + // then: All 3 recovered buffers should be persisted as inflight data assertThat(persistedBuffers).isNotNull().hasSize(3); assertThat(persistedBuffers.stream().mapToInt(Buffer::getSize).toArray()) .containsExactly(10, 20, 30); @@ -824,9 +832,6 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { // given: Local input channel with recovered buffers but NO subpartition view initialized SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - LocalInputChannel channel = new LocalInputChannel( inputGate, @@ -840,9 +845,11 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); // Do NOT call channel.requestSubpartitions() — subpartitionView stays null channel.notifyPriorityEvent(0); @@ -962,11 +969,6 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( TestingResultPartitionManager partitionManager = new TestingResultPartitionManager(subpartitionView); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - for (int size : recoveredBufferSizes) { - recoveredBuffers.add(TestBufferFactory.createBuffer(size)); - } - LocalInputChannel channel = new 
LocalInputChannel( inputGate, @@ -980,9 +982,15 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); + + for (int size : recoveredBufferSizes) { + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(size)); + } + channel.requestSubpartitions(); return new ChannelAndSubpartition(channel, subpartition); @@ -998,6 +1006,340 @@ private static class ChannelAndSubpartition { } } + // --------------------------------------------------------------------------------------------- + // RecoverableInputChannel push-based recovery tests + // --------------------------------------------------------------------------------------------- + + private static LocalInputChannel newPushOnlyLocalChannel( + SingleInputGate inputGate, ChannelStateWriter stateWriter) { + return new LocalInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + new ResultPartitionManager(), + new TaskEventDispatcher(), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + stateWriter, + 2, + true); + } + + @Test + void testOnRecoveredStateBufferEnqueues() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(22)); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getSize()).isEqualTo(11); + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getSize()).isEqualTo(22); + } + + @Test + void testOnRecoveredStateBufferOnReleasedChannelIsSilentlyRecycled() throws Exception { + SingleInputGate 
inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.releaseAllResources(); + + Buffer b = TestBufferFactory.createBuffer(33); + channel.onRecoveredStateBuffer(b); + assertThat(b.isRecycled()).isTrue(); + } + + @Test + void testOnRecoveredStateBufferNotifiesChannelNonEmptyOnEmptyToNonEmptyTransition() + throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + CompletableFuture availability = inputGate.getAvailableFuture(); + assertThat(availability).isNotDone(); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + assertThat(availability).isDone(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueEmptyReturnsEmpty() throws Exception { + // Drive into the (flag=false, queue=empty) boundary by pushing one buffer, polling it, + // and verifying the channel does not yet expose a master-path result. + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.getNextBuffer(); + // Queue is empty; no finishRecoveredBufferDelivery was called. Without the subpartitionView + // active, the channel returns empty. 
+ Optional result = channel.getNextBuffer(); + assertThat(result).isNotPresent(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(7)); + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(7); + } + + @Test + void testInRecoveryBoundaryFlagTrueQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(8)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(8); + } + + @Test + void testFinishWithNoRecoveredBuffersEmitsSentinelThenFallsToMasterPath() throws Exception { + // With no recovered buffers, finish still appends the EndOfFetchedChannelStateEvent + // sentinel; getNextBuffer returns it. Consuming the sentinel flips the channel out of + // recovery, after which the master path takes over: without a subpartition view it raises + // IllegalStateException via checkAndWaitForSubpartitionView, proving the recovery branch is + // no longer swallowing the call. 
+ SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat(EventSerializer.fromBuffer(sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + channel.onRecoveredStateConsumed(); + assertThatThrownBy(channel::getNextBuffer) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Queried for a buffer before requesting the subpartition"); + } + + @Test + void testPriorityEventDuringRecoveryFetchedFromSubpartitionView() throws Exception { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + ChannelAndSubpartition ctx = createChannelWithRecoveredBuffers(stateWriter, 10, 20); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + ctx.subpartition.add( + EventSerializer.toBufferConsumer(new CheckpointBarrier(1L, 0L, options), true)); + ctx.channel.notifyPriorityEvent(0); + + Optional first = ctx.channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getDataType().hasPriority()).isTrue(); + } + + @Test + void testPriorityEventDuringRecoveryResetAfterNonPriority() throws Exception { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + ChannelAndSubpartition ctx = createChannelWithRecoveredBuffers(stateWriter, 10); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + ctx.subpartition.add( + EventSerializer.toBufferConsumer(new CheckpointBarrier(1L, 0L, options), true)); + ctx.subpartition.add(createFilledFinishedBufferConsumer(32)); + ctx.channel.notifyPriorityEvent(0); + + Optional priority = 
ctx.channel.getNextBuffer(); + assertThat(priority).isPresent(); + Optional recovered = ctx.channel.getNextBuffer(); + assertThat(recovered).isPresent(); + assertThat(recovered.get().buffer().isBuffer()).isTrue(); + assertThat(recovered.get().buffer().getSize()).isEqualTo(10); + } + + @Test + void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + Buffer b2 = TestBufferFactory.createBuffer(2); + Buffer b3 = TestBufferFactory.createBuffer(3); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(b3); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + List persisted = stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(2); + assertThat(persisted.stream().mapToInt(Buffer::getSize).toArray()).containsExactly(1, 2); + } + + @Test + void testCheckpointStartedDeclinesAsNotReadyWhenRecoveryBarrierIsMissing() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + int refCntBefore = b1.refCnt(); + + CheckpointOptions options = + 
CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + + // A missing RecoveryCheckpointBarrier means the channel is not yet ready to snapshot + // recovered state for this checkpoint, so it declines as TASK_NOT_READY (not a fatal + // CHECKPOINT_DECLINED): the checkpoint is deferred/retried and the recovered buffer is + // neither dropped nor persisted. + assertThatThrownBy(() -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options))) + .isInstanceOfSatisfying( + CheckpointException.class, + e -> + assertThat(e.getCheckpointFailureReason()) + .isEqualTo( + CheckpointFailureReason + .CHECKPOINT_DECLINED_TASK_NOT_READY)) + .hasMessageContaining("not yet present in channel"); + assertThat(b1.refCnt()).isEqualTo(refCntBefore); + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); + } + + @Test + void testCheckpointStartedRetainsPreBarrierBuffers() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + int before = b1.refCnt(); + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + // Pre-barrier buffers are retained for the writer; their ref-count stays at or above the + // pre-checkpoint value. 
+ assertThat(b1.refCnt()).isGreaterThanOrEqualTo(before); + } + + @Test + void testCheckpointStartedRemovesSentinel() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + Optional h1 = channel.getNextBuffer(); + assertThat(h1).isPresent(); + Optional h2 = channel.getNextBuffer(); + assertThat(h2).isPresent(); + assertThat(h2.get().buffer().getSize()).isEqualTo(2); + } + + @Test + void testCheckpointStartedNestedCpIds() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(2L), false)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + List persisted = 
stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(1); + assertThat(persisted.get(0).getSize()).isEqualTo(1); + } + + @Test + void testCheckpointStartedNotInRecoveryUsesMasterPath() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + // Channel starts in-recovery; finish appends the sentinel and consuming it flips the + // channel to not-in-recovery so checkpointStarted exercises the master path instead of the + // recovery branch. + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + channel.getNextBuffer(); + channel.onRecoveredStateConsumed(); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isNullOrEmpty(); + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferIsTrueOnLocal() throws Exception { + // Local has no receivedBuffers; the helper is trivially true. We exercise the + // in-recovery checkpointStarted branch without any "live data" infrastructure to confirm + // the channel does not crash. 
+ SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + } + // --------------------------------------------------------------------------------------------- /** Returns the configured number of buffers for each channel in a random order. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java new file mode 100644 index 00000000000000..005c86d3ed0111 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.util.TestBufferFactory; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class LocalRecoveredInputChannelTest { + + @Test + void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalRecoveredInputChannel recoveredChannel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .buildLocalRecoveredChannel(inputGate); + + Buffer buffer = TestBufferFactory.createBuffer(11); + recoveredChannel.onRecoveredStateBuffer(buffer); + + try { + recoveredChannel.finishReadRecoveredState(); + assertThatThrownBy(() -> recoveredChannel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); + } finally { + recoveredChannel.releaseAllResources(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java new file mode 100644 index 00000000000000..32247c0cc77991 --- /dev/null +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; +import org.apache.flink.runtime.memory.MemoryManager; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +class RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest { + + private NetworkBufferPool pool; + + @AfterEach + void tearDown() { + if (pool != null) { + pool.destroy(); + pool = null; + } + } + + @Test + void 
testBufferPoolExhaustedBlocksRatherThanHeapAllocate() throws Exception { + int totalSegments = 4; + pool = new NetworkBufferPool(totalSegments, MemoryManager.DEFAULT_PAGE_SIZE); + RecoveredInputChannel channel = buildChannel(pool, totalSegments); + + for (int i = 0; i < totalSegments; i++) { + channel.requestBufferBlocking(); + } + + CountDownLatch entered = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + Thread blocker = + new Thread( + () -> { + try { + entered.countDown(); + result.set(channel.requestBufferBlocking()); + } catch (Exception ignored) { + // Expected: this test interrupts the thread after its assertion. + } + }, + "blocking-requester"); + blocker.start(); + + assertThat(entered.await(5, TimeUnit.SECONDS)).isTrue(); + Thread.sleep(200); + assertThat(result.get()).as("buffer should not have been allocated").isNull(); + + blocker.interrupt(); + blocker.join(5_000); + } + + @Test + void testFilterOnPathTakesSameRouteAsFilterOff() throws Exception { + int exclusivePerChannel = 1; + int totalSegments = 4; + pool = new NetworkBufferPool(totalSegments, MemoryManager.DEFAULT_PAGE_SIZE); + + Buffer filterOnBuf = buildChannel(pool, exclusivePerChannel).requestBufferBlocking(); + Buffer filterOffBuf = buildChannel(pool, exclusivePerChannel).requestBufferBlocking(); + + // Both must come from the pool — the BufferManager-owned recycler, not the + // FreeingBufferRecycler the heap-fallback used.
+ assertThat(filterOnBuf.getMemorySegment()).isNotNull(); + assertThat(filterOffBuf.getMemorySegment()).isNotNull(); + assertThat(filterOnBuf.getRecycler().getClass().getName()) + .doesNotContain("FreeingBufferRecycler"); + assertThat(filterOffBuf.getRecycler().getClass().getName()) + .doesNotContain("FreeingBufferRecycler"); + + filterOnBuf.recycleBuffer(); + filterOffBuf.recycleBuffer(); + } + + private RecoveredInputChannel buildChannel( + NetworkBufferPool segmentProvider, int exclusivePerChannel) { + try { + SingleInputGate inputGate = + new SingleInputGateBuilder().setSegmentProvider(segmentProvider).build(); + return new RecoveredInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + exclusivePerChannel) { + @Override + protected InputChannel toInputChannelInternal(boolean needsRecovery) { + throw new AssertionError("not expected during this test"); + } + }; + } catch (Exception e) { + throw new AssertionError("channel construction failed", e); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index f40fd09702ede8..ba78b5c2aed6b0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -22,14 +22,13 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; -import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.junit.jupiter.api.Test; import java.io.IOException; -import java.util.ArrayDeque; import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; @@ -39,16 +38,6 @@ /** Tests for {@link RecoveredInputChannel}. */ class RecoveredInputChannelTest { - @Test - void testConversionOnlyPossibleAfterBufferFilteringComplete() { - // toInputChannel() always checks bufferFilteringCompleteFuture regardless of config - for (boolean configEnabled : new boolean[] {true, false}) { - assertThatThrownBy(() -> buildChannel(configEnabled).toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - } - } - @Test void testRequestPartitionsImpossible() { assertThatThrownBy(() -> buildChannel(false).requestSubpartitions()) @@ -71,93 +60,58 @@ void testCheckpointStartImpossible() { } @Test - void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException { - // When config is enabled, conversion is allowed when bufferFilteringCompleteFuture is done - TestableRecoveredInputChannel channel = buildTestableChannel(true); - - // Initially, conversion should fail - assertThatThrownBy(() -> channel.toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - - // After finishReadRecoveredState(), bufferFilteringCompleteFuture should be done - channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); - assertThat(channel.getStateConsumedFuture()).isNotDone(); - - // Conversion should now succeed (no exception) - InputChannel converted = channel.toInputChannel(); - assertThat(converted).isNotNull(); - } - - @Test - void 
testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException { - // When config is disabled, conversion requires both bufferFilteringCompleteFuture - // and stateConsumedFuture to be done + void testToInputChannelRejectedWhileRecoveredStateUnconsumed() throws IOException { + // Conversion is rejected while recovered state is still queued: finishReadRecoveredState() + // enqueues the EndOfInputChannelStateEvent sentinel, so receivedBuffers is non-empty until + // it is consumed. The empty-queue check thus also guarantees stateConsumedFuture is done. TestableRecoveredInputChannel channel = buildTestableChannel(false); - // Initially, conversion should fail (buffer filtering not complete) - assertThatThrownBy(() -> channel.toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - - // After finishReadRecoveredState(), bufferFilteringCompleteFuture is done - // but stateConsumedFuture is not channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); assertThat(channel.getStateConsumedFuture()).isNotDone(); - // Conversion should still fail because stateConsumedFuture is not done - assertThatThrownBy(() -> channel.toInputChannel()) + // Conversion fails because the sentinel is still queued. + assertThatThrownBy(() -> channel.toInputChannel(true)) .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("recovered state is not fully consumed"); + .hasMessageContaining("Received buffer should be empty"); - // Consume the EndOfInputChannelStateEvent to complete stateConsumedFuture + // Consuming the EndOfInputChannelStateEvent should complete the future. + // getNextBuffer() returns empty when it encounters the event internally. 
assertThat(channel.getNextBuffer()).isNotPresent(); assertThat(channel.getStateConsumedFuture()).isDone(); // Now conversion should succeed - InputChannel converted = channel.toInputChannel(); + InputChannel converted = channel.toInputChannel(true); assertThat(converted).isNotNull(); } @Test - void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException { - // finishReadRecoveredState() unconditionally completes bufferFilteringCompleteFuture - for (boolean configEnabled : new boolean[] {true, false}) { - RecoveredInputChannel channel = buildChannel(configEnabled); - assertThat(channel.getBufferFilteringCompleteFuture()).isNotDone(); - channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); - } + void testToInputChannelRequiresEmptyRecoveredBuffers() throws IOException { + TestableRecoveredInputChannel channel = buildTestableChannel(true); + + channel.onRecoveredStateBuffer(BufferBuilderTestUtils.buildSomeBuffer()); + channel.finishReadRecoveredState(); + + assertThatThrownBy(() -> channel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); } @Test - void testStateConsumedFutureCompletesAfterConsumingAllBuffers() throws IOException { - // This test verifies that stateConsumedFuture completes after consuming - // EndOfInputChannelStateEvent regardless of the config setting - for (boolean configEnabled : new boolean[] {true, false}) { - RecoveredInputChannel channel = buildChannel(configEnabled); + void testStateConsumedFutureCompletesAfterLegacySentinelIsConsumed() throws IOException { + RecoveredInputChannel channel = buildChannel(false); - assertThat(channel.getStateConsumedFuture()).isNotDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); - channel.finishReadRecoveredState(); - assertThat(channel.getStateConsumedFuture()).isNotDone(); + channel.finishReadRecoveredState(); + 
assertThat(channel.getStateConsumedFuture()).isNotDone(); - // Consuming the EndOfInputChannelStateEvent should complete the future. - // getNextBuffer() returns empty when it encounters the event internally. - assertThat(channel.getNextBuffer()).isNotPresent(); - assertThat(channel.getStateConsumedFuture()).isDone(); - } + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); } private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { try { - SingleInputGate inputGate = - new SingleInputGateBuilder() - .setCheckpointingDuringRecoveryEnabled( - checkpointingDuringRecoveryEnabled) - .build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); return new RecoveredInputChannel( inputGate, 0, @@ -169,7 +123,7 @@ private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEn new SimpleCounter(), 10) { @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { throw new AssertionError("channel conversion succeeded"); } }; @@ -181,11 +135,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer private TestableRecoveredInputChannel buildTestableChannel( boolean checkpointingDuringRecoveryEnabled) { try { - SingleInputGate inputGate = - new SingleInputGateBuilder() - .setCheckpointingDuringRecoveryEnabled( - checkpointingDuringRecoveryEnabled) - .build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); return new TestableRecoveredInputChannel(inputGate); } catch (Exception e) { throw new AssertionError("channel creation failed", e); @@ -210,7 +160,7 @@ private static class TestableRecoveredInputChannel extends RecoveredInputChannel } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { return new 
TestInputChannel(inputGate, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index e47de93c9e8bdf..1bd6489dc9d331 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -23,10 +23,14 @@ import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; +import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.io.network.ConnectionID; @@ -37,6 +41,7 @@ import org.apache.flink.runtime.io.network.TestingConnectionManager; import org.apache.flink.runtime.io.network.TestingPartitionRequestClient; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.Buffer.DataType; @@ -72,6 +77,7 @@ import 
javax.annotation.Nullable; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; @@ -2079,15 +2085,8 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { // given: RemoteInputChannel with recovered buffers migrated from RecoveredInputChannel SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - ConnectionID connectionId = - new ConnectionID( - org.apache.flink.runtime.clusterframework.types.ResourceID.generate(), - new java.net.InetSocketAddress("localhost", 0), - 0); + new ConnectionID(ResourceID.generate(), new InetSocketAddress("localhost", 0), 0); RemoteInputChannel channel = new RemoteInputChannel( inputGate, @@ -2104,10 +2103,15 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + true); inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + // then: Can read recovered buffers even before requestSubpartitions() Optional first = channel.getNextBuffer(); assertThat(first).isPresent(); @@ -2119,6 +2123,494 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { assertThat(second.get().buffer().getSize()).isEqualTo(20); } + // --------------------------------------------------------------------------------------------- + // RecoverableInputChannel push-based recovery tests + // --------------------------------------------------------------------------------------------- + + @Test + void testOnRecoveredStateBufferEnqueues() throws Exception { + 
SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(11); + Buffer b2 = TestBufferFactory.createBuffer(22); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getSize()).isEqualTo(11); + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getSize()).isEqualTo(22); + } + + @Test + void testRecoveredBuffersConsumedBeforeStashedEventsThenSentinelFlipsRecovery() + throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + // Recovered buffer arrives via the drain; an ordinary upstream event arrives via onBuffer + // while still in recovery and must be stashed (events carry no credit). + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + // backlog=-1: events carry no backlog, and this channel has no floating-buffer pool wired. + channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE, false), 0, -1, 0); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + // Recovered buffer is consumed first. + Optional recovered = channel.getNextBuffer(); + assertThat(recovered).isPresent(); + assertThat(recovered.get().buffer().getSize()).isEqualTo(11); + + // Then the sentinel; the stashed event is not yet visible (still in recovery). 
+ Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat(EventSerializer.fromBuffer(sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + // The gate consumes the sentinel externally: flips out of recovery and unstashes the event. + channel.onRecoveredStateConsumed(); + Optional stashed = channel.getNextBuffer(); + assertThat(stashed).isPresent(); + assertThat(EventSerializer.fromBuffer(stashed.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfPartitionEvent.class); + } + + @Test + void testOnBufferRejectsLiveDataBufferDuringRecovery() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + assertThatThrownBy(() -> channel.onBuffer(TestBufferFactory.createBuffer(1), 0, 0, 0)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received live data buffer during recovery"); + } + + @Test + void testOnRecoveredStateBufferOnReleasedChannelIsSilentlyRecycled() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.releaseAllResources(); + + Buffer b = TestBufferFactory.createBuffer(33); + channel.onRecoveredStateBuffer(b); + + assertThat(b.isRecycled()).isTrue(); + } + + @Test + void testOnRecoveredStateBufferNotifiesChannelNonEmptyOnEmptyToNonEmptyTransition() + throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + 
.buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + CompletableFuture availability = inputGate.getAvailableFuture(); + assertThat(availability).isNotDone(); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + assertThat(availability).isDone(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueEmptyReturnsEmpty() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + // Force in-recovery + empty queue: push then poll a buffer, then push another while + // delaying finish. After the consumer drains the staged buffer, queue=empty and + // flag=false (no finishRecoveredBufferDelivery called yet). getNextBuffer must return + // empty. + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.getNextBuffer(); + // Simulate an explicit recovery context where the producer signals "not done yet". + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.getNextBuffer(); + // The boundary case "flag=false (drain still running) + queue empty" should return empty. + // To set this state explicitly, we deliberately do not call finishRecoveredBufferDelivery.
+ Optional result = channel.getNextBuffer(); + assertThat(result).isNotPresent(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(7)); + + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(7); + } + + @Test + void testInRecoveryBoundaryFlagTrueQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(8)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(8); + } + + @Test + void testFinishWithNoRecoveredBuffersEmitsSentinelThenFallsToMasterPath() throws Exception { + // Wire a real network pool so requestSubpartitions() can succeed and the master path can + // poll receivedBuffers. 
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + channel.completeUpstreamReadyForTest(); + // Even with no recovered buffers, finish appends the EndOfFetchedChannelStateEvent + // sentinel so the consume path can flip out of recovery in order. + channel.finishRecoveredBufferDelivery(); + inputGate.requestPartitions(); + + Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat( + EventSerializer.fromBuffer( + sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + // Consuming the sentinel (done externally by the gate) flips the channel out of + // recovery; afterwards the master path is taken and there is no more queued data. + channel.onRecoveredStateConsumed(); + assertThat(channel.getNextBuffer()).isNotPresent(); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testMoreAvailableNoneWhenLastRecoveredBufferAndDrainNotFinished() throws Exception { + // While the channel is in recovery and the drain has not finished, the last currently + // queued recovered buffer must report NONE as its next data type: no live data can enter + // receivedBuffers (the upstream has no credit), so there is nothing else to expose yet. 
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + + Optional recoveredBuf = channel.getNextBuffer(); + assertThat(recoveredBuf).isPresent(); + assertThat(recoveredBuf.get().buffer().getSize()).isEqualTo(11); + // Drain not finished and queue now empty: nothing more is available yet. + assertThat(recoveredBuf.get().moreAvailable()).isFalse(); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testPriorityEventDuringRecoveryViaAddPriorityBuffer() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + + CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, UNALIGNED); + channel.onBuffer(toBuffer(barrier, true), 0, 0, 0); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getDataType().hasPriority()).isTrue(); + + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getSize()).isEqualTo(11); + } finally { + networkBufferPool.destroy(); + } + } + + @Test 
+ void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + Buffer b2 = TestBufferFactory.createBuffer(2); + Buffer b3 = TestBufferFactory.createBuffer(3); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(b3); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + List persisted = stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(2); + assertThat(persisted.stream().mapToInt(Buffer::getSize).toArray()).containsExactly(1, 2); + } + + @Test + void testCheckpointStartedDeclinesAsNotReadyWhenRecoveryBarrierIsMissing() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + int refCntBefore = b1.refCnt(); + + stateWriter.start(1L, UNALIGNED); + + // A missing RecoveryCheckpointBarrier means the channel is not yet ready to snapshot + // recovered state for this checkpoint, so it declines as TASK_NOT_READY (not a fatal + // CHECKPOINT_DECLINED): the checkpoint is deferred/retried and the recovered buffer is + // neither dropped nor 
persisted. + assertThatThrownBy( + () -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED))) + .isInstanceOfSatisfying( + CheckpointException.class, + e -> + assertThat(e.getCheckpointFailureReason()) + .isEqualTo( + CheckpointFailureReason + .CHECKPOINT_DECLINED_TASK_NOT_READY)) + .hasMessageContaining("not yet present in channel"); + assertThat(b1.refCnt()).isEqualTo(refCntBefore); + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); + } + + @Test + void testCheckpointStartedRetainsPreBarrierBuffers() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + int before = b1.refCnt(); + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + // After retainBuffer the buffer remains live for both the queue read and the writer copy. 
+ assertThat(b1.refCnt()).isGreaterThanOrEqualTo(before); + } + + @Test + void testCheckpointStartedRemovesSentinel() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + Optional head = channel.getNextBuffer(); + assertThat(head).isPresent(); + Optional nextHead = channel.getNextBuffer(); + assertThat(nextHead).isPresent(); + assertThat(nextHead.get().buffer().getSize()).isEqualTo(2); + } + + @Test + void testCheckpointStartedNestedCpIds() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(2L), false)); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + List persisted1 = stateWriter.getAddedInput().get(channel.getChannelInfo()); + 
assertThat(persisted1).hasSize(1); + assertThat(persisted1.get(0).getSize()).isEqualTo(1); + } + + @Test + void testCheckpointStartedNotInRecoveryUsesMasterPath() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.requestSubpartitions(); + stateWriter.start(7L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(7L, 0L, UNALIGNED)); + + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isNullOrEmpty(); + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferDetectsLiveData() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + // During recovery the upstream has no credit and can only send events. A live data + // buffer is a protocol violation that onBuffer must reject at the entry point. 
+ assertThatThrownBy(() -> channel.onBuffer(TestBufferFactory.createBuffer(1), 0, 0, 0)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received live data buffer during recovery"); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferAcceptsPriorityOnly() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + // Mirror what snapshotAndInsertBarriers does: push the + // RecoveryCheckpointBarrier sentinel so collectPreRecoveryBarrier finds it. + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + // Priority event in receivedBuffers is OK (!isBuffer()). 
+ CheckpointBarrier priorityBarrier = new CheckpointBarrier(1L, 0L, UNALIGNED); + channel.onBuffer(toBuffer(priorityBarrier, true), 0, 0, 0); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + } finally { + networkBufferPool.destroy(); + } + } + private static final class TestBufferPool extends NoOpBufferPool { @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java new file mode 100644 index 00000000000000..afa6d6fc6d8b82 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.util.TestBufferFactory; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class RemoteRecoveredInputChannelTest { + + @Test + void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RemoteRecoveredInputChannel recoveredChannel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .buildRemoteRecoveredChannel(inputGate); + + Buffer buffer = TestBufferFactory.createBuffer(13); + recoveredChannel.onRecoveredStateBuffer(buffer); + + try { + recoveredChannel.finishReadRecoveredState(); + assertThatThrownBy(() -> recoveredChannel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); + } finally { + recoveredChannel.releaseAllResources(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index a4da811f8a32a4..e4a4c289dc6e8d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -83,8 +83,6 @@ public class SingleInputGateBuilder { private TieredStorageConsumerClient tieredStorageConsumerClient = null; - private boolean isCheckpointingDuringRecoveryEnabled = false; - public SingleInputGateBuilder setPartitionProducerStateProvider( PartitionProducerStateProvider partitionProducerStateProvider) { @@ -169,11 
+167,6 @@ public SingleInputGateBuilder setTieredStorageConsumerClient( return this; } - public SingleInputGateBuilder setCheckpointingDuringRecoveryEnabled(boolean enabled) { - this.isCheckpointingDuringRecoveryEnabled = enabled; - return this; - } - public SingleInputGate build() { SingleInputGate gate = new SingleInputGate( @@ -202,7 +195,6 @@ public SingleInputGate build() { .toArray(InputChannel[]::new)); } gate.setTieredStorageService(null, tieredStorageConsumerClient, null); - gate.setCheckpointingDuringRecoveryEnabled(isCheckpointingDuringRecoveryEnabled); return gate; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index b2cc9d7ce3c9ca..f7f0b744fb9fdd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -142,36 +142,6 @@ void testCheckpointsDeclinedUnlessStateConsumed() { .isInstanceOf(CheckpointException.class); } - @Test - void testBufferFilteringCompleteFutureAggregation() throws Exception { - final NettyShuffleEnvironment environment = createNettyShuffleEnvironment(); - final SingleInputGate inputGate = createInputGate(environment); - try (Closer closer = Closer.create()) { - closer.register(environment::close); - closer.register(inputGate::close); - - // Enable unaligned during recovery for this test so that - // bufferFilteringCompleteFuture is completed by finishReadRecoveredState() - inputGate.setCheckpointingDuringRecoveryEnabled(true); - inputGate.setup(); - - // Initially, the aggregated future should not be completed - assertThat(inputGate.getBufferFilteringCompleteFuture()).isNotDone(); - - // After finishing read recovered state, bufferFilteringCompleteFuture should be - // 
completed (only when config is enabled) - inputGate.finishReadRecoveredState(); - assertThat(inputGate.getBufferFilteringCompleteFuture()).isDone(); - - // stateConsumedFuture should not be completed until data is consumed - assertThat(inputGate.getStateConsumedFuture()).isNotDone(); - - // Consuming the EndOfInputChannelStateEvent should complete stateConsumedFuture - inputGate.pollNext(); - assertThat(inputGate.getStateConsumedFuture()).isDone(); - } - } - /** * Tests {@link InputGate#setup()} should create the respective {@link BufferPool} and assign * exclusive buffers for {@link RemoteInputChannel}s, but should not request partitions. diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java index 14a3654a666399..1759d10cd5c238 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.api.EndOfData; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; @@ -31,8 +32,10 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; +import java.util.Deque; import java.util.Optional; import java.util.Queue; import java.util.concurrent.CompletableFuture; @@ -45,7 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** A mocked input channel. 
*/ -public class TestInputChannel extends InputChannel { +public class TestInputChannel extends InputChannel implements RecoverableInputChannel { private final Queue buffers = new ConcurrentLinkedQueue<>(); @@ -259,6 +262,45 @@ public void notifyRequiredSegmentId(int subpartitionId, int segmentId) { requiredSegmentIdFuture.complete(segmentId); } + private final Deque recoveredBuffersSpy = new ArrayDeque<>(); + private boolean finishRecoveredBufferDeliveryCalled = false; + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + recoveredBuffersSpy.add(buffer); + } + + @Override + public void finishRecoveredBufferDelivery() { + finishRecoveredBufferDeliveryCalled = true; + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + if (!finishRecoveredBufferDeliveryCalled || !recoveredBuffersSpy.isEmpty()) { + recoveredBuffersSpy.add( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + + @Override + public Buffer requestRecoveryBufferBlocking() { + throw new UnsupportedOperationException("TestInputChannel does not back recovery drain"); + } + + @Override + public void onRecoveredStateConsumed() { + // No-op in this test stub. 
+ } + + public Deque getRecoveredBuffersSpy() { + return recoveredBuffersSpy; + } + + public boolean isFinishRecoveredBufferDeliveryCalled() { + return finishRecoveredBufferDeliveryCalled; + } + public void assertReturnedEventsAreRecycled() { assertReturnedBuffersAreRecycled(false, true); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 1ed1a42a66ea01..419246137e8c26 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -18,14 +18,12 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.io.PullingAsyncDataInput; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.partition.NoOpResultSubpartitionView; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; import org.junit.jupiter.api.Test; @@ -277,62 +275,6 @@ void testGetChannelWithShiftedGateIndexes() { assertThat(unionInputGate.getChannel(1)).isEqualTo(inputChannel2); } - @Test - void testBufferFilteringCompleteFutureAggregation() throws IOException { - // Create 2 SingleInputGates, each with 1 RecoveredInputChannel - SingleInputGate ig1 = - new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); - RecoveredInputChannel channel1 = 
buildRecoveredChannel(ig1); - ig1.setInputChannels(channel1); - - SingleInputGate ig2 = - new SingleInputGateBuilder() - .setSingleInputGateIndex(1) - .setCheckpointingDuringRecoveryEnabled(true) - .build(); - RecoveredInputChannel channel2 = buildRecoveredChannel(ig2); - ig2.setInputChannels(channel2); - - UnionInputGate union = new UnionInputGate(ig1, ig2); - - // Initially, bufferFilteringCompleteFuture should not be done - assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); - assertThat(union.getStateConsumedFuture()).isNotDone(); - - // Complete buffer filtering on first gate only - channel1.finishReadRecoveredState(); - assertThat(ig1.getBufferFilteringCompleteFuture()).isDone(); - assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); - - // Complete buffer filtering on second gate - channel2.finishReadRecoveredState(); - assertThat(ig2.getBufferFilteringCompleteFuture()).isDone(); - assertThat(union.getBufferFilteringCompleteFuture()).isDone(); - - // State consumed futures should still NOT be done (state not consumed yet) - assertThat(union.getStateConsumedFuture()).isNotDone(); - } - - private static RecoveredInputChannel buildRecoveredChannel(SingleInputGate inputGate) { - return new RecoveredInputChannel( - inputGate, - 0, - new ResultPartitionID(), - new ResultSubpartitionIndexSet(0), - 0, - 0, - new SimpleCounter(), - new SimpleCounter(), - 10) { - @Override - protected InputChannel toInputChannelInternal( - java.util.ArrayDeque - remainingBuffers) { - throw new UnsupportedOperationException(); - } - }; - } - @Test void testEmptyPull() throws IOException, InterruptedException { final SingleInputGate inputGate1 = createInputGate(1); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java index f64a4d9fb9cac7..1eb305cf10a4d0 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java @@ -53,10 +53,12 @@ import org.apache.flink.util.function.SupplierWithException; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -78,6 +80,8 @@ /** ChannelPersistenceITCase. */ class ChannelPersistenceITCase { + @TempDir static Path tmpDir; + private static final Random RANDOM = new Random(System.currentTimeMillis()); private static final JobID JOB_ID = new JobID(); private static final JobVertexID JOB_VERTEX_ID = new JobVertexID(); @@ -120,7 +124,10 @@ void testReadWritten() throws Exception { try { int numChannels = 1; InputGate gate = buildGate(networkBufferPool, numChannels); - reader.readInputData(new InputGate[] {gate}, RecordFilterContext.disabled()); + reader.readInputData( + new InputGate[] {gate}, + RecordFilterContext.disabled( + new String[] {tmpDir.toAbsolutePath().toString()})); assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer)) .isEqualTo(inputChannelInfoData); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java index 584aeb7eb9089f..53bba67f7a2be7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java @@ -57,11 +57,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return 
CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -86,6 +81,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {} @@ -142,12 +142,4 @@ public ResultPartitionType getConsumedPartitionType() { @Override public void triggerDebloating() {} - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return false; - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index 71b2c43f3306aa..47e3b79a77fc13 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -80,11 +80,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -101,6 +96,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public List getChannelInfos() { return IntStream.range(0, numberOfChannels) @@ -204,12 +204,4 @@ public int getGateIndex() { public List getUnfinishedChannels() { return Collections.emptyList(); } - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - 
public boolean isCheckpointingDuringRecoveryEnabled() { - return false; - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java index b850a7cc553702..b3f0aec9dc5e53 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java @@ -37,7 +37,6 @@ import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration; import java.io.IOException; -import java.util.ArrayDeque; /** * A benchmark-specific input gate factory which overrides the respective methods of creating {@link @@ -130,7 +129,8 @@ public TestLocalInputChannel( metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + 0, + false); } @Override @@ -186,7 +186,7 @@ public TestRemoteInputChannel( metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java index 619873c387d08f..ad792d221c2efb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java @@ -178,6 +178,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel getChannel(InputChannelInfo 
channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public List getChannelInfos() { return IntStream.range(0, numberOfChannels) @@ -263,11 +268,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -286,13 +286,5 @@ public int getGateIndex() { public List getUnfinishedChannels() { return Collections.emptyList(); } - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return false; - } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java new file mode 100644 index 00000000000000..dfd865172882f6 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Source-level invariants for the {@code alignedCheckpointTimeout} → UC switch in {@link + * AlternatingCollectingBarriers}. Note this is NOT a {@code barrierReceived} hook — the parent + * {@code AbstractAlternatingAlignedBarrierHandlerState.barrierReceived} is {@code final}; the only + * place to switch from aligned to UC is {@code alignedCheckpointTimeout}. + */ +class AlternatingCollectingBarriersDispatchHookTest { + + @Test + void testDispatcherIsCalledBeforeTriggerGlobalCheckpoint() throws Exception { + String source = readSource(); + + int initIdx = source.indexOf("controller.initInputsCheckpoint(unalignedBarrier)"); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(unalignedBarrier)"); + int triggerIdx = source.indexOf("controller.triggerGlobalCheckpoint(unalignedBarrier)"); + + assertThat(initIdx).as("initInputsCheckpoint call exists").isNotNegative(); + assertThat(dispatchIdx).as("dispatcher call exists").isNotNegative(); + assertThat(triggerIdx).as("triggerGlobalCheckpoint call exists").isNotNegative(); + + assertThat(initIdx) + .as( + "initInputsCheckpoint precedes dispatcher (cpId result registered before " + + "Step 3)") + .isLessThan(dispatchIdx); + assertThat(dispatchIdx) + .as("dispatcher precedes triggerGlobalCheckpoint (master ordering)") + .isLessThan(triggerIdx); + } + + @Test + void testHookIsInAlignedCheckpointTimeoutNotBarrierReceived() throws Exception { + String source = readSource(); + + // Hard guard: AbstractAlternatingAlignedBarrierHandlerState.barrierReceived is final, so + // AlternatingCollectingBarriers cannot override it. 
The UC switch lives in + // alignedCheckpointTimeout only. + assertThat(source) + .as("alignedCheckpointTimeout is the UC switch entry on this state class") + .contains("public BarrierHandlerState alignedCheckpointTimeout("); + assertThat(source) + .as("No barrierReceived override should exist on AlternatingCollectingBarriers") + .doesNotContain("public BarrierHandlerState barrierReceived("); + + int timeoutIdx = source.indexOf("alignedCheckpointTimeout("); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(", timeoutIdx); + assertThat(dispatchIdx) + .as("dispatcher call is inside alignedCheckpointTimeout body") + .isNotNegative(); + } + + @Test + void testNoLegacyPerInputCheckpointStartedLoopInTimeout() throws Exception { + String source = readSource(); + + int methodStart = source.indexOf("public BarrierHandlerState alignedCheckpointTimeout("); + assertThat(methodStart).isNotNegative(); + int methodEnd = findMethodEnd(source, methodStart); + String body = source.substring(methodStart, methodEnd); + + assertThat(body) + .as( + "Pre-trigger per-input checkpointStarted loop should be removed; dispatcher " + + "now owns the fan-out") + .doesNotContain("input.checkpointStarted(unalignedBarrier)"); + } + + private static String readSource() throws Exception { + Path candidate = + Paths.get( + "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java"); + if (!Files.exists(candidate)) { + candidate = + Paths.get( + "flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java"); + } + return new String(Files.readAllBytes(candidate)); + } + + private static int findMethodEnd(String source, int start) { + int firstBrace = source.indexOf('{', start); + int depth = 1; + for (int i = firstBrace + 1; i < source.length(); i++) { + char c = source.charAt(i); + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + if (depth == 0) { + return i + 1; + } + } 
+ } + return source.length(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java new file mode 100644 index 00000000000000..036e5ad70c9684 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Source-level invariants for the recovery-checkpoint dispatcher hook in {@link + * AlternatingWaitingForFirstBarrierUnaligned}: + * + *

    <ol> + *
  <li>The dispatcher call exists. + *
  <li>It runs after {@code controller.initInputsCheckpoint} (so the cpId's {@code + * ChannelStateWriteResult} is registered before the writer is invoked) and before {@code + * controller.triggerGlobalCheckpoint}. + *
  <li>No vestigial per-input {@code checkpointStarted} loop remains alongside the dispatcher. + *
</ol>
+ */ +class AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest { + + @Test + void testDispatcherIsCalledBeforeTriggerGlobalCheckpoint() throws Exception { + String source = readSource(); + + int initIdx = source.indexOf("controller.initInputsCheckpoint(unalignedBarrier)"); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(unalignedBarrier)"); + int triggerIdx = source.indexOf("controller.triggerGlobalCheckpoint(unalignedBarrier)"); + + assertThat(initIdx).as("initInputsCheckpoint call exists").isNotNegative(); + assertThat(dispatchIdx).as("dispatcher call exists").isNotNegative(); + assertThat(triggerIdx).as("triggerGlobalCheckpoint call exists").isNotNegative(); + + assertThat(initIdx) + .as( + "initInputsCheckpoint must precede the dispatcher so the cpId result is " + + "registered when Step 3 fires") + .isLessThan(dispatchIdx); + assertThat(dispatchIdx) + .as("dispatcher must precede triggerGlobalCheckpoint (master ordering)") + .isLessThan(triggerIdx); + } + + @Test + void testNoLegacyPerInputCheckpointStartedLoopInBarrierReceived() throws Exception { + String source = readSource(); + + // barrierReceived still iterates input.checkpointStopped in the allBarriersReceived + // branch — that loop is fine. We only forbid a parallel input.checkpointStarted loop, + // which the dispatcher now owns. 
+ int methodStart = source.indexOf("public BarrierHandlerState barrierReceived("); + assertThat(methodStart).isNotNegative(); + + int methodEnd = findMethodEnd(source, methodStart); + String body = source.substring(methodStart, methodEnd); + + assertThat(body) + .as( + "Pre-trigger per-input checkpointStarted loop should be removed; dispatcher " + + "now owns the fan-out") + .doesNotContain("input.checkpointStarted(unalignedBarrier)"); + } + + private static String readSource() throws Exception { + Path candidate = + Paths.get( + "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java"); + if (!Files.exists(candidate)) { + candidate = + Paths.get( + "flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java"); + } + return new String(Files.readAllBytes(candidate)); + } + + private static int findMethodEnd(String source, int start) { + int firstBrace = source.indexOf('{', start); + int depth = 1; + for (int i = firstBrace + 1; i < source.length(); i++) { + char c = source.charAt(i); + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + if (depth == 0) { + return i + 1; + } + } + } + return source.length(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java new file mode 100644 index 00000000000000..53c355727fc2b3 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.CheckpointType; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; +import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; +import org.apache.flink.runtime.state.CheckpointStorageLocationReference; +import org.apache.flink.util.CloseableIterator; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static 
org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Verifies the {@link ChannelState#onCheckpointStartedForAllInputs} dispatcher: call ordering,
+ * feature-off no-op routing through the {@link RecoveryCheckpointTrigger#NO_OP} singleton, and
+ * absence of an outer feature-flag branch.
+ */
+class ChannelStateDispatcherTest {
+
+    private static final long CHECKPOINT_ID = 7L;
+
+    /** Path of {@code ChannelState.java}, relative to the {@code flink-runtime} directory. */
+    private static final String CHANNEL_STATE_SOURCE =
+            "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/"
+                    + "ChannelState.java";
+
+    /**
+     * With a recording trigger, the dispatcher must call the trigger first, then every input in
+     * array order, and finally hand the reader to the writer.
+     */
+    @Test
+    void testStepOrderingFeatureOn() throws Exception {
+        List<String> trace = new ArrayList<>();
+        // An empty reader is sufficient to verify ordering.
+        FetchedChannelStateReader snap = FetchedChannelStateReader.emptyReader();
+        RecordingTrigger trigger = new RecordingTrigger(trace, snap);
+        RecordingWriter writer = new RecordingWriter(trace);
+        CheckpointableInput input1 = new RecordingInput(trace, "in1");
+        CheckpointableInput input2 = new RecordingInput(trace, "in2");
+
+        ChannelState state =
+                new ChannelState(new CheckpointableInput[] {input1, input2}, trigger, writer);
+
+        CheckpointBarrier barrier = newUnalignedBarrier();
+        state.onCheckpointStartedForAllInputs(barrier);
+
+        assertThat(trace)
+                .containsExactly(
+                        "trigger.snapshotAndInsertBarriers:" + CHECKPOINT_ID,
+                        "in1.checkpointStarted:" + CHECKPOINT_ID,
+                        "in2.checkpointStarted:" + CHECKPOINT_ID,
+                        "writer.addInputDataFromSpill:" + CHECKPOINT_ID);
+    }
+
+    /**
+     * With {@link RecoveryCheckpointTrigger#NO_OP}, the input and the writer are still called and
+     * the reader that reaches the writer has no segments.
+     */
+    @Test
+    void testStepOrderingFeatureOff() throws Exception {
+        List<String> trace = new ArrayList<>();
+        RecordingWriter writer = new RecordingWriter(trace);
+        CheckpointableInput input = new RecordingInput(trace, "in1");
+
+        ChannelState state =
+                new ChannelState(
+                        new CheckpointableInput[] {input}, RecoveryCheckpointTrigger.NO_OP, writer);
+
+        state.onCheckpointStartedForAllInputs(newUnalignedBarrier());
+
+        assertThat(trace)
+                .containsExactly(
+                        "in1.checkpointStarted:" + CHECKPOINT_ID,
+                        "writer.addInputDataFromSpill:" + CHECKPOINT_ID);
+        assertThat(writer.lastSnapshotWasEmpty.get()).isTrue();
+    }
+
+    @Test
+    void testEmptySnapshotStillSubmitted() throws Exception {
+        // Empty readers (no spill files) are no longer short-circuited; they still reach
+        // addInputDataFromSpill on the writer thread.
+        List<String> trace = new ArrayList<>();
+        FetchedChannelStateReader emptySnap = FetchedChannelStateReader.emptyReader();
+        RecordingTrigger trigger = new RecordingTrigger(trace, emptySnap);
+        RecordingWriter writer = new RecordingWriter(trace);
+
+        ChannelState state =
+                new ChannelState(
+                        new CheckpointableInput[] {new RecordingInput(trace, "in1")},
+                        trigger,
+                        writer);
+
+        state.onCheckpointStartedForAllInputs(newUnalignedBarrier());
+
+        // Empty reader must still reach the writer (no inline short-circuit).
+        assertThat(writer.addInputDataFromSpillCalls.get()).isEqualTo(1);
+        assertThat(writer.lastSnapshotWasEmpty.get()).isTrue();
+    }
+
+    @Test
+    void testNoIfFilterOnInDispatcher() throws Exception {
+        // Branch-free routing through the null-object trigger is a hard correctness invariant;
+        // a feature-flag check inside the dispatcher would silently bypass it. Guard against
+        // that by scanning the dispatcher source for "filter" / "feature".
+        Path candidate = Paths.get(CHANNEL_STATE_SOURCE);
+        if (!Files.exists(candidate)) {
+            candidate = Paths.get("flink-runtime", CHANNEL_STATE_SOURCE);
+        }
+        assertThat(Files.exists(candidate))
+                .as("Located ChannelState.java for the source-level invariant check")
+                .isTrue();
+
+        // Files.readString always decodes as UTF-8, independent of the platform default charset.
+        String all = Files.readString(candidate);
+        int methodStart = all.indexOf("public void onCheckpointStartedForAllInputs");
+        assertThat(methodStart).isNotNegative();
+        // The dispatcher body is taken to end at the next private method, or at end of file.
+        int methodEnd = all.indexOf(" private void", methodStart);
+        String body = all.substring(methodStart, methodEnd > 0 ? methodEnd : all.length());
+
+        // Drop line comments so that prose mentioning "filter" / "feature" is not flagged.
+        StringBuilder code = new StringBuilder();
+        for (String line : body.split("\n")) {
+            int idx = line.indexOf("//");
+            code.append(idx >= 0 ? line.substring(0, idx) : line).append('\n');
+        }
+        String codeOnly = code.toString();
+
+        assertThat(codeOnly)
+                .as("Dispatcher must not branch on filter / feature flags")
+                .doesNotContain("filter")
+                .doesNotContain("feature");
+    }
+
+    /** Creates an unaligned checkpoint barrier carrying {@link #CHECKPOINT_ID}. */
+    private static CheckpointBarrier newUnalignedBarrier() {
+        return new CheckpointBarrier(
+                CHECKPOINT_ID,
+                1000L,
+                CheckpointOptions.unaligned(
+                        CheckpointType.CHECKPOINT,
+                        CheckpointStorageLocationReference.getDefault()));
+    }
+
+    /** Trigger double that records each call in the shared trace and returns a fixed reader. */
+    private static final class RecordingTrigger implements RecoveryCheckpointTrigger {
+        private final List<String> trace;
+        private final FetchedChannelStateReader snapshot;
+
+        RecordingTrigger(List<String> trace, FetchedChannelStateReader snapshot) {
+            this.trace = trace;
+            this.snapshot = snapshot;
+        }
+
+        @Override
+        public FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) {
+            trace.add("trigger.snapshotAndInsertBarriers:" + checkpointId);
+            return snapshot;
+        }
+    }
+
+    /**
+     * Writer double that records {@code addInputDataFromSpill} calls in the shared trace; every
+     * other {@link ChannelStateWriter} method is a no-op.
+     */
+    private static final class RecordingWriter implements ChannelStateWriter {
+        private final List<String> trace;
+        /** Whether the reader passed to the most recent call yielded no segment. */
+        final AtomicBoolean lastSnapshotWasEmpty = new AtomicBoolean(false);
+        /** Checkpoint id of the most recent call, or -1 before the first call. */
+        final AtomicLong lastCpId = new AtomicLong(-1L);
+        /** Number of {@code addInputDataFromSpill} calls seen so far. */
+        final AtomicInteger addInputDataFromSpillCalls = new AtomicInteger(0);
+
+        RecordingWriter(List<String> trace) {
+            this.trace = trace;
+        }
+
+        @Override
+        public void start(long checkpointId, CheckpointOptions checkpointOptions) {}
+
+        @Override
+        public void addInputData(
+                long checkpointId,
+                InputChannelInfo info,
+                int startSeqNum,
+                CloseableIterator<Buffer> data) {}
+
+        @Override
+        public void addOutputData(
+                long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... data) {}
+
+        @Override
+        public void addOutputDataFuture(
+                long checkpointId,
+                ResultSubpartitionInfo info,
+                int startSeqNum,
+                CompletableFuture<List<Buffer>> data) {}
+
+        @Override
+        public void finishInput(long checkpointId) {}
+
+        @Override
+        public void finishOutput(long checkpointId) {}
+
+        @Override
+        public void abort(long checkpointId, Throwable cause, boolean cleanup) {}
+
+        @Override
+        public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) {
+            return ChannelStateWriteResult.EMPTY;
+        }
+
+        @Override
+        public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) {
+            trace.add("writer.addInputDataFromSpill:" + checkpointId);
+            lastCpId.set(checkpointId);
+            addInputDataFromSpillCalls.incrementAndGet();
+            try {
+                try {
+                    // Peek whether the reader has any segments by attempting the first advance.
+                    // The first nextSegment() call is exempt from the "previous body consumed"
+                    // rule.
+                    lastSnapshotWasEmpty.set(reader.nextSegment().isEmpty());
+                } finally {
+                    // Release the reader even when the peek itself failed.
+                    reader.close();
+                }
+            } catch (Exception e) {
+                // Surface the real cause instead of leaving a misleading "was not empty" state.
+                throw new AssertionError(
+                        "Failed to inspect the spill reader for checkpoint " + checkpointId, e);
+            }
+        }
+
+        @Override
+        public void close() {}
+    }
+
+    /** Input double that records {@code checkpointStarted} calls; it exposes no channels. */
+    private static final class RecordingInput implements CheckpointableInput {
+
+        private final List<String> trace;
+        private final String name;
+
+        RecordingInput(List<String> trace, String name) {
+            this.trace = trace;
+            this.name = name;
+        }
+
+        @Override
+        public void blockConsumption(InputChannelInfo channelInfo) {}
+
+        @Override
+        public void resumeConsumption(InputChannelInfo channelInfo) {}
+
+        @Override
+        public List<InputChannelInfo> getChannelInfos() {
+            return Collections.emptyList();
+        }
+
+        @Override
+        public int getNumberOfInputChannels() {
+            return 0;
+        }
+
+        @Override
+        public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException {
+            trace.add(name + ".checkpointStarted:" + barrier.getId());
+        }
+
+        @Override
+        public void checkpointStopped(long cancelledCheckpointId) {}
+
+        @Override
+        public int getInputGateIndex() {
+            return 0;
+        }
+
+        @Override
+        public void convertToPriorityEvent(int channelIndex, int sequenceNumber) {}
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java index 1b345ebe0d2a05..8c281c04a6b13f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java @@ -20,6 +20,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator; @@ -72,6 +73,8 @@ public SingleCheckpointBarrierHandler create( inputGate.getNumberOfInputChannels(), actionRegistration, enableCheckpointsAfterTasksFinish, + RecoveryCheckpointTrigger.NO_OP, + stateWriter, inputGate); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java index 60494f7acbe5c7..bd17c8d33f0dbe 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java @@ -36,9 +36,10 @@ class RecordFilterContextTest { @Test void testDisabledContextHasNoGates() { - RecordFilterContext disabled = RecordFilterContext.disabled(); + RecordFilterContext disabled = RecordFilterContext.disabled(new String[] {"/tmp"});
assertThat(disabled.getNumberOfGates()).isEqualTo(0); assertThat(disabled.isCheckpointingDuringRecoveryEnabled()).isFalse(); + assertThat(disabled.getTmpDirectories()).containsExactly("/tmp"); } @Test @@ -72,7 +73,7 @@ void testGetInputConfigThrowsForInvalidIndex() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE); @@ -83,8 +84,38 @@ void testGetInputConfigThrowsForInvalidIndex() { } @Test - void testNullTmpDirectoriesConvertedToEmptyArray() { - RecordFilterContext context = + void testEnabledContextRejectsNullOrEmptyTmpDirectories() { + // When checkpointing-during-recovery is enabled, the spilling path needs spill + // directories, so null/empty tmpDirectories are rejected. + assertThatThrownBy( + () -> + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + true, + MemoryManager.DEFAULT_PAGE_SIZE)) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy( + () -> + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[0], + true, + MemoryManager.DEFAULT_PAGE_SIZE)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testDisabledContextToleratesNullOrEmptyTmpDirectories() { + // A disabled context never spills, so it needs no spill directories: null/empty are + // tolerated and normalized to an empty array. 
+ RecordFilterContext fromNull = new RecordFilterContext( new RecordFilterContext.InputFilterConfig[0], InflightDataRescalingDescriptor.NO_RESCALE, @@ -93,8 +124,18 @@ void testNullTmpDirectoriesConvertedToEmptyArray() { null, false, MemoryManager.DEFAULT_PAGE_SIZE); + assertThat(fromNull.getTmpDirectories()).isEmpty(); - assertThat(context.getTmpDirectories()).isNotNull().isEmpty(); + RecordFilterContext fromEmpty = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[0], + false, + MemoryManager.DEFAULT_PAGE_SIZE); + assertThat(fromEmpty.getTmpDirectories()).isEmpty(); } @Test @@ -111,7 +152,7 @@ void testIsAmbiguousWhenDisabled() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE); @@ -132,7 +173,7 @@ void testIsAmbiguousWhenEnabled() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, true, MemoryManager.DEFAULT_PAGE_SIZE); @@ -152,7 +193,7 @@ void testIsAmbiguousForNonAmbiguousSubtask() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, true, MemoryManager.DEFAULT_PAGE_SIZE); @@ -170,7 +211,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE * 2); @@ -184,7 +225,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, 0)) .isInstanceOf(IllegalArgumentException.class); @@ -195,7 +236,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, -1)) .isInstanceOf(IllegalArgumentException.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java index 
098d86587f93eb..b064511e6e6f62 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java @@ -94,6 +94,7 @@ import org.apache.flink.util.concurrent.Executors; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -112,6 +113,7 @@ * checks correct working of different policies how tasks deal with checkpoint failures (fail task, * decline checkpoint and continue). */ +@Timeout(120) class TaskCheckpointingBehaviourTest { private static final OneShotLatch IN_CHECKPOINT_LATCH = new OneShotLatch();