diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java index 74eef6c93366e0..193d46a3e607fd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.state.InputChannelStateHandle; @@ -72,7 +73,9 @@ public void fail(Throwable e) { resultSubpartitionStateHandles.completeExceptionally(e); } - boolean isDone() { + /** Returns {@code true} once both state handle futures are done. */ + @VisibleForTesting + public boolean isDone() { return inputChannelStateHandles.isDone() && resultSubpartitionStateHandles.isDone(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java index 40d09f1a21a948..b1091d2b2bf650 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.state.CheckpointStateOutputStream; @@ -223,6 +224,12 @@ public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) { return result; } 
+ /** Test-only: returns the checkpoint's write result without removing it, or null if none. */ + @VisibleForTesting + public ChannelStateWriteResult getWriteResult(long checkpointId) { + return results.get(checkpointId); + } + public void open() { executor.start(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java index 48698a2c13e1f9..6d4de5aacf7c68 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java @@ -316,6 +316,10 @@ public void checkpointState( // broadcast cancel checkpoint marker to avoid downstream back-pressure due to // checkpoint barrier align. operatorChain.broadcastEvent(new CancelCheckpointMarker(metadata.getCheckpointId())); + channelStateWriter.abort( + metadata.getCheckpointId(), + new CancellationException("checkpoint aborted via notification"), + true); LOG.info( "Checkpoint {} has been notified as aborted, would not trigger any checkpoint.", metadata.getCheckpointId()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java index 92e526e773c3b4..8d21e8bbb28f74 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java @@ -81,11 +81,15 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static 
org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT; +import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory; +import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -568,6 +572,67 @@ public void snapshotState( } } + /** + * Verifies that a checkpoint aborted via notification before it starts does not leak its + * channel state write result: the result is dropped from the writer and both of its handle + * futures complete exceptionally. + */ + @Test + public void testChannelStateWriteResultLeakAndNotFailAfterCheckpointAborted() throws Exception { + MockEnvironment mockEnvironment = MockEnvironment.builder().build(); + final OperatorChain operatorChain = getOperatorChain(mockEnvironment); + String taskName = "test"; + // Close the writer when done so that what open() started does not outlive the test. + try (ChannelStateWriterImpl writer = + new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory())) { + writer.open(); + SubtaskCheckpointCoordinator coordinator = + new SubtaskCheckpointCoordinatorImpl( + new TestCheckpointStorageWorkerView(100), + taskName, + StreamTaskActionExecutor.IMMEDIATE, + newDirectExecutorService(), + new DummyEnvironment(), + (unused1, unused2) -> {}, + (unused1, unused2) -> CompletableFuture.completedFuture(null), + 128, + writer, + true, + (callable, duration) -> () -> {}); + int checkpointId = 1; + // The abort notification is delivered before the checkpoint is initialized below. + coordinator.notifyCheckpointAborted(checkpointId, operatorChain, () -> true); + + coordinator.initInputsCheckpoint( + checkpointId, unaligned(CheckpointType.CHECKPOINT, getDefault())); + ChannelStateWriter.ChannelStateWriteResult writeResult = + writer.getWriteResult(checkpointId); + assertNotNull(writeResult); + assertFalse(writeResult.isDone()); + assertFalse(writeResult.getInputChannelStateHandles().isCompletedExceptionally()); +
assertFalse( + writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally()); + + coordinator.checkpointState( + new CheckpointMetaData(checkpointId, System.currentTimeMillis()), + CheckpointOptions.forCheckpointWithDefaultLocation(), + new CheckpointMetricsBuilder(), + operatorChain, + false, + () -> true); + assertNull(writer.getWriteResult(checkpointId)); + // Bounded wait for the result to complete, instead of a fixed 10 ms sleep. + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + while (!writeResult.isDone() && System.nanoTime() < deadlineNanos) { + TimeUnit.MILLISECONDS.sleep(1); + } + assertTrue(writeResult.isDone()); + assertTrue(writeResult.getInputChannelStateHandles().isCompletedExceptionally()); + assertTrue( + writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally()); + } + } + private OperatorChain getOperatorChain(MockEnvironment mockEnvironment) throws Exception { return new RegularOperatorChain<>( new MockStreamTaskBuilder(mockEnvironment).build(), new NonRecordWriter<>());