Skip to content

Commit

Permalink
[FLINK-28474][checkpoint] Fix the bug ChannelStateWriteResult might not fail after checkpoint abort
Browse files Browse the repository at this point in the history
  • Loading branch information
1996fanrui committed Oct 6, 2022
1 parent 17af0d2 commit 52f3fdb
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
Expand Up @@ -18,6 +18,7 @@
package org.apache.flink.runtime.checkpoint.channel;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.InputChannelStateHandle;
Expand Down Expand Up @@ -72,9 +73,21 @@ public void fail(Throwable e) {
resultSubpartitionStateHandles.completeExceptionally(e);
}

boolean isDone() {
public boolean isDone() {
return inputChannelStateHandles.isDone() && resultSubpartitionStateHandles.isDone();
}

/**
 * Blocks until both state-handle futures have finished, regardless of their outcome.
 *
 * <p>A failed or cancelled future is not reported to the caller; callers inspect the futures
 * afterwards. If the calling thread is interrupted while waiting, the interrupt flag is
 * restored and the method returns without waiting any further.
 */
@VisibleForTesting
public void waitForDone() {
    try {
        inputChannelStateHandles.get();
    } catch (InterruptedException e) {
        // Preserve the interrupt for the caller instead of silently dropping it.
        Thread.currentThread().interrupt();
        return;
    } catch (Exception ignored) {
        // Only completion matters here; the failure itself is checked by the caller.
    }
    try {
        resultSubpartitionStateHandles.get();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    } catch (Exception ignored) {
        // Only completion matters here; the failure itself is checked by the caller.
    }
}
}

/**
Expand Down
Expand Up @@ -18,6 +18,7 @@
package org.apache.flink.runtime.checkpoint.channel;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
Expand Down Expand Up @@ -223,6 +224,12 @@ public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) {
return result;
}

/**
 * Test-only accessor: looks up the write result stored for the given checkpoint and leaves it
 * registered.
 *
 * @param checkpointId id of the checkpoint whose write result is requested
 * @return whatever {@code results} holds for {@code checkpointId}
 */
@VisibleForTesting
public ChannelStateWriteResult getWriteResult(long checkpointId) {
    final ChannelStateWriteResult registered = results.get(checkpointId);
    return registered;
}

/** Opens this writer by starting its executor. */
public void open() {
executor.start();
}
Expand Down
Expand Up @@ -95,6 +95,10 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
/** The IDs of the checkpoint for which we are notified aborted. */
private final Set<Long> abortedCheckpointIds;

/**
 * Size limit used when creating {@link #abortedCheckpointIds}. Also used by
 * checkAndClearAbortedStatus as the distance below {@link #maxAbortedCheckpointId} beyond which
 * a checkpoint id is treated as aborted.
 */
private final int maxRecordAbortedCheckpoints;

/** Largest checkpoint id recorded as aborted so far; starts at 0. */
private long maxAbortedCheckpointId = 0;

/**
 * Set to -1 in the constructor; abort notifications are only recorded for ids greater than this
 * value. NOTE(review): presumably the id of the last checkpoint triggered on this task - confirm.
 */
private long lastCheckpointId;

/** Lock that guards state of AsyncCheckpointRunnable registry. */
Expand Down Expand Up @@ -180,6 +184,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
this.prepareInputSnapshot = prepareInputSnapshot;
this.abortedCheckpointIds =
createAbortedCheckpointSetWithLimitSize(maxRecordAbortedCheckpoints);
this.maxRecordAbortedCheckpoints = maxRecordAbortedCheckpoints;
this.lastCheckpointId = -1L;
this.closed = false;
this.enableCheckpointAfterTasksFinished = enableCheckpointAfterTasksFinished;
Expand Down Expand Up @@ -287,6 +292,10 @@ public void checkpointState(
// broadcast cancel checkpoint marker to avoid downstream back-pressure due to
// checkpoint barrier align.
operatorChain.broadcastEvent(new CancelCheckpointMarker(metadata.getCheckpointId()));
channelStateWriter.abort(
metadata.getCheckpointId(),
new CancellationException("checkpoint aborted via notification"),
true);
LOG.info(
"Checkpoint {} has been notified as aborted, would not trigger any checkpoint.",
metadata.getCheckpointId());
Expand Down Expand Up @@ -432,6 +441,7 @@ private void notifyCheckpoint(
if (checkpointId > lastCheckpointId) {
// only record checkpoints that have not triggered on task side.
abortedCheckpointIds.add(checkpointId);
maxAbortedCheckpointId = Math.max(maxAbortedCheckpointId, checkpointId);
}
}

Expand Down Expand Up @@ -544,7 +554,8 @@ int getAbortedCheckpointSize() {
}

/**
 * Decides whether the given checkpoint must be treated as aborted, clearing its recorded abort
 * notification if there is one.
 *
 * <p>A checkpoint counts as aborted when its id was present in {@code abortedCheckpointIds}
 * (it is removed as a side effect), or when its id lies more than
 * {@code maxRecordAbortedCheckpoints} below {@code maxAbortedCheckpointId}. The second case
 * covers ids that are no longer in the size-limited set.
 *
 * @param checkpointId id of the checkpoint to check
 * @return {@code true} if the checkpoint should be handled as aborted
 */
private boolean checkAndClearAbortedStatus(long checkpointId) {
    return abortedCheckpointIds.remove(checkpointId)
            || checkpointId + maxRecordAbortedCheckpoints < maxAbortedCheckpointId;
}

private void registerAsyncCheckpointRunnable(
Expand Down
Expand Up @@ -82,10 +82,12 @@
import java.util.function.Supplier;

import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -569,6 +571,58 @@ public void snapshotState(
}
}

/**
 * Checks that a channel-state write result created for a checkpoint that was already notified
 * as aborted is removed from the writer and completed exceptionally once checkpointState runs
 * for that checkpoint.
 */
@Test
public void testChannelStateWriteResultLeakAndNotFailAfterCheckpointAborted() throws Exception {
String taskName = "test";
// The writer and coordinator are real implementations wired together; both are closed by the
// try-with-resources block.
try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
ChannelStateWriterImpl writer =
new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
SubtaskCheckpointCoordinator coordinator =
new SubtaskCheckpointCoordinatorImpl(
new TestCheckpointStorageWorkerView(100),
taskName,
StreamTaskActionExecutor.IMMEDIATE,
newDirectExecutorService(),
new DummyEnvironment(),
(unused1, unused2) -> {},
(unused1, unused2) -> CompletableFuture.completedFuture(null),
1,
writer,
true,
(callable, duration) -> () -> {})) {
writer.open();
final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
int checkpointId = 1;
// Abort checkpoint 1
coordinator.notifyCheckpointAborted(checkpointId, operatorChain, () -> true);

// Initialize the inputs checkpoint for the same id after the abort notification; the
// assertions below expect a write result to exist that is neither done nor failed.
coordinator.initInputsCheckpoint(
checkpointId,
CheckpointOptions.unaligned(
CheckpointType.CHECKPOINT,
CheckpointStorageLocationReference.getDefault()));
ChannelStateWriter.ChannelStateWriteResult writeResult =
writer.getWriteResult(checkpointId);
assertNotNull(writeResult);
assertFalse(writeResult.isDone());
assertFalse(writeResult.getInputChannelStateHandles().isCompletedExceptionally());
assertFalse(writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally());

// Running checkpointState for the aborted id is expected to drop the write result from
// the writer and to fail both of its futures.
coordinator.checkpointState(
new CheckpointMetaData(checkpointId, System.currentTimeMillis()),
CheckpointOptions.forCheckpointWithDefaultLocation(),
new CheckpointMetricsBuilder(),
operatorChain,
false,
() -> true);
assertNull(writer.getWriteResult(checkpointId));
// Wait for both futures to finish before asserting on their final state.
writeResult.waitForDone();
assertTrue(writeResult.isDone());
assertTrue(writeResult.getInputChannelStateHandles().isCompletedExceptionally());
assertTrue(writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally());
}
}

private OperatorChain<?, ?> getOperatorChain(MockEnvironment mockEnvironment) throws Exception {
return new RegularOperatorChain<>(
new MockStreamTaskBuilder(mockEnvironment).build(), new NonRecordWriter<>());
Expand Down

0 comments on commit 52f3fdb

Please sign in to comment.