From 90ca438106e63c5032ee2ad27e54e9f573eac386 Mon Sep 17 00:00:00 2001 From: Stephan Ewen Date: Mon, 27 Mar 2017 17:20:47 +0200 Subject: [PATCH] [FLINK-6390] [checkpoints] Add API for checkpoints that are triggered via external systems This includes - A interface for hooks that are called by the checkpoint coordinator to trigger/restore a checkpoint - A source extension that triggers the operator checkpoints and barrier injection on certain events Because this changes the checkpoint metadata format, the commit introduces a new metadata format version. This closes #3782 --- .../core/io/SimpleVersionedSerializer.java | 80 +++ .../org/apache/flink/util/StringUtils.java | 7 + .../checkpoint/savepoint/SavepointV0.java | 10 + .../savepoint/SavepointV0Serializer.java | 12 +- .../checkpoint/CheckpointCoordinator.java | 80 ++- .../checkpoint/CheckpointDeclineReason.java | 4 +- .../checkpoint/CompletedCheckpoint.java | 31 +- .../flink/runtime/checkpoint/MasterState.java | 62 +++ .../checkpoint/MasterTriggerRestoreHook.java | 140 ++++++ .../runtime/checkpoint/PendingCheckpoint.java | 23 +- .../runtime/checkpoint/hooks/MasterHooks.java | 273 ++++++++++ .../checkpoint/savepoint/Savepoint.java | 6 + .../checkpoint/savepoint/SavepointLoader.java | 3 +- .../savepoint/SavepointSerializers.java | 7 +- .../checkpoint/savepoint/SavepointV1.java | 36 +- .../savepoint/SavepointV1Serializer.java | 76 +-- .../checkpoint/savepoint/SavepointV2.java | 91 ++++ .../savepoint/SavepointV2Serializer.java | 468 ++++++++++++++++++ .../executiongraph/ExecutionGraph.java | 9 + .../executiongraph/ExecutionGraphBuilder.java | 18 + .../tasks/JobCheckpointingSettings.java | 48 +- .../CheckpointCoordinatorMasterHooksTest.java | 421 ++++++++++++++++ .../CompletedCheckpointStoreTest.java | 2 +- .../checkpoint/CompletedCheckpointTest.java | 24 +- ...ecutionGraphCheckpointCoordinatorTest.java | 1 + ...ntV1Test.java => CheckpointTestUtils.java} | 89 ++-- .../savepoint/SavepointLoaderTest.java | 2 +- 
.../savepoint/SavepointStoreTest.java | 27 +- .../savepoint/SavepointV1SerializerTest.java | 17 +- .../savepoint/SavepointV2SerializerTest.java | 148 ++++++ .../checkpoint/savepoint/SavepointV2Test.java | 68 +++ .../ArchivedExecutionGraphTest.java | 2 + .../checkpoint/ExternallyInducedSource.java | 75 +++ .../checkpoint/WithMasterCheckpointHook.java | 38 ++ .../FunctionMasterCheckpointHookFactory.java | 45 ++ .../api/graph/StreamingJobGraphGenerator.java | 28 +- .../runtime/tasks/SourceStreamTask.java | 56 +++ .../WithMasterCheckpointHookConfigTest.java | 189 +++++++ .../runtime/io/StreamRecordWriterTest.java | 5 - .../SourceExternalCheckpointTriggerTest.java | 171 +++++++ .../runtime/tasks/StreamTaskTestHarness.java | 7 +- .../test/checkpointing/SavepointITCase.java | 4 +- 42 files changed, 2748 insertions(+), 155 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/core/io/SimpleVersionedSerializer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterTriggerRestoreHook.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/hooks/MasterHooks.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java rename flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/{SavepointV1Test.java => CheckpointTestUtils.java} (69%) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java create 
mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java diff --git a/flink-core/src/main/java/org/apache/flink/core/io/SimpleVersionedSerializer.java b/flink-core/src/main/java/org/apache/flink/core/io/SimpleVersionedSerializer.java new file mode 100644 index 0000000000000..6c061a54ac127 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/core/io/SimpleVersionedSerializer.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.core.io; + +import java.io.IOException; + +/** + * A simple serializer interface for versioned serialization. + * + *

The serializer has a version (returned by {@link #getVersion()}) which can be attached + * to the serialized data. When the serializer evolves, the version can be used to identify + * with which prior version the data was serialized. + * + *

{@code
+ * MyType someObject = ...;
+ * SimpleVersionedSerializer<MyType> serializer = ...;
+ *
+ * byte[] serializedData = serializer.serialize(someObject);
+ * int version = serializer.getVersion();
+ *
+ * MyType deserialized = serializer.deserialize(version, serializedData);
+ * 
+ * byte[] someOldData = ...;
+ * int oldVersion = ...;
+ * MyType deserializedOldObject = serializer.deserialize(oldVersion, someOldData);
+ * 
+ * }
+ * + * @param The data type serialized / deserialized by this serializer. + */ +public interface SimpleVersionedSerializer extends Versioned { + + /** + * Gets the version with which this serializer serializes. + * + * @return The version of the serialization schema. + */ + @Override + int getVersion(); + + /** + * Serializes the given object. The serialization is assumed to correspond to the + * current serialization version (as returned by {@link #getVersion()}. + * + * + * @param obj The object to serialize. + * @return The serialized data (bytes). + * + * @throws IOException Thrown, if the serialization fails. + */ + byte[] serialize(E obj) throws IOException; + + /** + * De-serializes the given data (bytes) which was serialized with the scheme of the + * indicated version. + * + * @param version The version in which the data was serialized + * @param serialized The serialized data + * @return The deserialized object + * + * @throws IOException Thrown, if the deserialization fails. + */ + E deserialize(int version, byte[] serialized) throws IOException; +} diff --git a/flink-core/src/main/java/org/apache/flink/util/StringUtils.java b/flink-core/src/main/java/org/apache/flink/util/StringUtils.java index b84f602d26bac..abd6ba6a4fae5 100644 --- a/flink-core/src/main/java/org/apache/flink/util/StringUtils.java +++ b/flink-core/src/main/java/org/apache/flink/util/StringUtils.java @@ -309,6 +309,13 @@ public static void writeNullableString(@Nullable String str, DataOutputView out) } } + /** + * Checks if the string is null, empty, or contains only whitespace characters. + * A whitespace character is defined via {@link Character#isWhitespace(char)}. + * + * @param str The string to check + * @return True, if the string is null or blank, false otherwise. 
+ */ public static boolean isNullOrWhitespaceOnly(String str) { if (str == null || str.length() == 0) { return true; diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java index 1c51a695c3dfa..f3ec1cf2f4b5f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java +++ b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java @@ -19,6 +19,7 @@ package org.apache.flink.migration.runtime.checkpoint.savepoint; import org.apache.flink.migration.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.savepoint.Savepoint; import org.apache.flink.util.Preconditions; @@ -58,6 +59,15 @@ public long getCheckpointId() { @Override public Collection getTaskStates() { + // since checkpoints are never deserialized into this format, + // this method should never be called + throw new UnsupportedOperationException(); + } + + @Override + public Collection getMasterStates() { + // since checkpoints are never deserialized into this format, + // this method should never be called throw new UnsupportedOperationException(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java index 4739033484194..d285906262565 100644 --- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0Serializer.java @@ -34,7 +34,7 @@ import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskStateList; import 
org.apache.flink.migration.util.SerializedValue; import org.apache.flink.runtime.checkpoint.savepoint.SavepointSerializer; -import org.apache.flink.runtime.checkpoint.savepoint.SavepointV1; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.CheckpointStreamFactory; @@ -68,7 +68,7 @@ * don't rely on any involved Java classes to stay the same. */ @SuppressWarnings("deprecation") -public class SavepointV0Serializer implements SavepointSerializer { +public class SavepointV0Serializer implements SavepointSerializer { public static final SavepointV0Serializer INSTANCE = new SavepointV0Serializer(); private static final StreamStateHandle SIGNAL_0 = new ByteStreamStateHandle("SIGNAL_0", new byte[]{0}); @@ -81,12 +81,12 @@ private SavepointV0Serializer() { @Override - public void serialize(SavepointV1 savepoint, DataOutputStream dos) throws IOException { + public void serialize(SavepointV2 savepoint, DataOutputStream dos) throws IOException { throw new UnsupportedOperationException("This serializer is read-only and only exists for backwards compatibility"); } @Override - public SavepointV1 deserialize(DataInputStream dis, ClassLoader userClassLoader) throws IOException { + public SavepointV2 deserialize(DataInputStream dis, ClassLoader userClassLoader) throws IOException { long checkpointId = dis.readLong(); @@ -165,7 +165,7 @@ private static SerializedValue> readSerializedValueStateHandle(Da return serializedValue; } - private SavepointV1 convertSavepoint( + private SavepointV2 convertSavepoint( List taskStates, ClassLoader userClassLoader, long checkpointID) throws Exception { @@ -176,7 +176,7 @@ private SavepointV1 convertSavepoint( newTaskStates.add(convertTaskState(taskState, userClassLoader, checkpointID)); } - return new SavepointV1(checkpointID, newTaskStates); + return new SavepointV2(checkpointID, 
newTaskStates); } private org.apache.flink.runtime.checkpoint.TaskState convertTaskState( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 256321eb854b1..23a38d4847f2c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -20,7 +20,9 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.checkpoint.hooks.MasterHooks; import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader; import org.apache.flink.runtime.checkpoint.savepoint.SavepointStore; import org.apache.flink.runtime.concurrent.ApplyFunction; @@ -39,7 +41,10 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; +import org.apache.flink.util.FlinkException; import org.apache.flink.util.Preconditions; +import org.apache.flink.util.StringUtils; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -138,6 +143,9 @@ public class CheckpointCoordinator { /** The timer that handles the checkpoint timeouts and triggers periodic checkpoints */ private final ScheduledThreadPoolExecutor timer; + /** The master checkpoint hooks executed by this checkpoint coordinator */ + private final HashMap> masterHooks; + /** Actor that receives status updates from the execution graph this coordinator works for */ private JobStatusListener jobStatusListener; @@ -220,6 +228,7 @@ public CheckpointCoordinator( this.executor = checkNotNull(executor); this.recentPendingCheckpoints = new 
ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS); + this.masterHooks = new HashMap<>(); this.timer = new ScheduledThreadPoolExecutor(1, new DispatcherThreadFactory(Thread.currentThread().getThreadGroup(), "Checkpoint Timer")); @@ -245,6 +254,45 @@ public CheckpointCoordinator( } } + // -------------------------------------------------------------------------------------------- + // Configuration + // -------------------------------------------------------------------------------------------- + + /** + * Adds the given master hook to the checkpoint coordinator. This method does nothing, if + * the checkpoint coordinator already contained a hook with the same ID (as defined via + * {@link MasterTriggerRestoreHook#getIdentifier()}). + * + * @param hook The hook to add. + * @return True, if the hook was added, false if the checkpoint coordinator already + * contained a hook with the same ID. + */ + public boolean addMasterHook(MasterTriggerRestoreHook hook) { + checkNotNull(hook); + + final String id = hook.getIdentifier(); + checkArgument(!StringUtils.isNullOrWhitespaceOnly(id), "The hook has a null or empty id"); + + synchronized (lock) { + if (!masterHooks.containsKey(id)) { + masterHooks.put(id, hook); + return true; + } + else { + return false; + } + } + } + + /** + * Gets the number of currently register master hooks. + */ + public int getNumberOfRegisteredMasterHooks() { + synchronized (lock) { + return masterHooks.size(); + } + } + /** * Sets the checkpoint stats tracker. 
* @@ -492,6 +540,20 @@ CheckpointTriggerResult triggerCheckpoint( checkpoint.setStatsCallback(callback); } + // trigger the master hooks for the checkpoint + try { + List masterStates = MasterHooks.triggerMasterHooks(masterHooks.values(), + checkpointID, timestamp, executor, Time.milliseconds(checkpointTimeout)); + + for (MasterState s : masterStates) { + checkpoint.addMasterState(s); + } + } + catch (FlinkException e) { + checkpoint.abortError(e); + return new CheckpointTriggerResult(CheckpointDeclineReason.EXCEPTION); + } + // schedule the timer that will clean up the expired checkpoints final Runnable canceller = new Runnable() { @Override @@ -962,13 +1024,25 @@ public boolean restoreLatestCheckpointedState( LOG.info("Restoring from latest valid checkpoint: {}.", latest); + // re-assign the task states + final Map taskStates = latest.getTaskStates(); StateAssignmentOperation stateAssignmentOperation = new StateAssignmentOperation(LOG, tasks, taskStates, allowNonRestoredState); - stateAssignmentOperation.assignStates(); + // call master hooks for restore + + MasterHooks.restoreMasterHooks( + masterHooks, + latest.getMasterHookStates(), + latest.getCheckpointID(), + allowNonRestoredState, + LOG); + + // update metrics + if (statsTracker != null) { long restoreTimestamp = System.currentTimeMillis(); RestoredCheckpointStats restored = new RestoredCheckpointStats( @@ -1022,9 +1096,9 @@ public boolean restoreSavepoint( return restoreLatestCheckpointedState(tasks, true, allowNonRestored); } - // -------------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------ // Accessors - // -------------------------------------------------------------------------------------------- + // ------------------------------------------------------------------------ public int getNumberOfPendingCheckpoints() { return this.pendingCheckpoints.size(); diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointDeclineReason.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointDeclineReason.java index 60fe657c851b1..41c50cc08e45e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointDeclineReason.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointDeclineReason.java @@ -36,7 +36,9 @@ public enum CheckpointDeclineReason { NOT_ALL_REQUIRED_TASKS_RUNNING("Not all required tasks are currently running."), - EXCEPTION("An Exception occurred while triggering the checkpoint."); + EXCEPTION("An Exception occurred while triggering the checkpoint."), + + EXPIRED("The checkpoint expired before triggering was complete"); // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index 79fc31f862344..bb49b45637b8b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -27,11 +27,16 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import static org.apache.flink.util.Preconditions.checkArgument; @@ -90,11 +95,14 @@ public class CompletedCheckpoint implements Serializable { private final long duration; /** States of the different task groups belonging to this checkpoint */ - private final Map taskStates; + private final 
HashMap taskStates; /** Properties for this checkpoint. */ private final CheckpointProperties props; + /** States that were created by a hook on the master (in the checkpoint coordinator) */ + private final Collection masterHookStates; + /** The state handle to the externalized meta data, if the metadata has been externalized */ @Nullable private final StreamStateHandle externalizedMetadata; @@ -118,6 +126,7 @@ public class CompletedCheckpoint implements Serializable { Map taskStates) { this(job, checkpointID, timestamp, completionTimestamp, taskStates, + Collections.emptyList(), CheckpointProperties.forStandardCheckpoint()); } @@ -127,9 +136,11 @@ public CompletedCheckpoint( long timestamp, long completionTimestamp, Map taskStates, + @Nullable Collection masterHookStates, CheckpointProperties props) { - this(job, checkpointID, timestamp, completionTimestamp, taskStates, props, null, null); + this(job, checkpointID, timestamp, completionTimestamp, taskStates, + masterHookStates, props, null, null); } public CompletedCheckpoint( @@ -138,6 +149,7 @@ public CompletedCheckpoint( long timestamp, long completionTimestamp, Map taskStates, + @Nullable Collection masterHookStates, CheckpointProperties props, @Nullable StreamStateHandle externalizedMetadata, @Nullable String externalPointer) { @@ -156,7 +168,14 @@ public CompletedCheckpoint( this.checkpointID = checkpointID; this.timestamp = timestamp; this.duration = completionTimestamp - timestamp; - this.taskStates = checkNotNull(taskStates); + + // we create copies here, to make sure we have no shared mutable + // data structure with the "outside world" + this.taskStates = new HashMap<>(checkNotNull(taskStates)); + this.masterHookStates = masterHookStates == null || masterHookStates.isEmpty() ? 
+ Collections.emptyList() : + new ArrayList<>(masterHookStates); + this.props = checkNotNull(props); this.externalizedMetadata = externalizedMetadata; this.externalPointer = externalPointer; @@ -228,13 +247,17 @@ public long getStateSize() { } public Map getTaskStates() { - return taskStates; + return Collections.unmodifiableMap(taskStates); } public TaskState getTaskState(JobVertexID jobVertexID) { return taskStates.get(jobVertexID); } + public Collection getMasterHookStates() { + return Collections.unmodifiableCollection(masterHookStates); + } + public boolean isExternalized() { return externalizedMetadata != null; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterState.java new file mode 100644 index 0000000000000..2d09fdb62430a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterState.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint; + +import java.util.Arrays; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Simple encapsulation of state generated by checkpoint coordinator. + */ +public class MasterState implements java.io.Serializable { + + private static final long serialVersionUID = 1L; + + private final String name; + private final byte[] bytes; + private final int version; + + public MasterState(String name, byte[] bytes, int version) { + this.name = checkNotNull(name); + this.bytes = checkNotNull(bytes); + this.version = version; + } + + // ------------------------------------------------------------------------ + + public String name() { + return name; + } + + public byte[] bytes() { + return bytes; + } + + public int version() { + return version; + } + + // ------------------------------------------------------------------------ + + @Override + public String toString() { + return "name: " + name + " ; version: " + version + " ; bytes: " + Arrays.toString(bytes); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterTriggerRestoreHook.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterTriggerRestoreHook.java new file mode 100644 index 0000000000000..e77ed57653143 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/MasterTriggerRestoreHook.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.concurrent.Future; + +import javax.annotation.Nullable; +import java.util.concurrent.Executor; + +/** + * The interface for hooks that can be called by the checkpoint coordinator when triggering or + * restoring a checkpoint. Such a hook is useful for example when preparing external systems for + * taking or restoring checkpoints. + * + *

The {@link #triggerCheckpoint(long, long, Executor)} method (called when triggering a checkpoint) + * can return a result (via a future) that will be stored as part of the checkpoint metadata. + * When restoring a checkpoint, that stored result will be given to the {@link #restoreCheckpoint(long, Object)} + * method. The hook's {@link #getIdentifier() identifier} is used to map data to hook in the presence + * of multiple hooks, and when resuming a savepoint that was potentially created by a different job. + * The identifier has a similar role as for example the operator UID in the streaming API. + * + *

The MasterTriggerRestoreHook is defined when creating the streaming dataflow graph. It is attached + * to the job graph, which gets sent to the cluster for execution. To avoid having to make the hook + * itself serializable, these hooks are attached to the job graph via a {@link MasterTriggerRestoreHook.Factory}. + * + * @param The type of the data produced by the hook and stored as part of the checkpoint metadata. + * If the hook never stores any data, this can be typed to {@code Void}. + */ +public interface MasterTriggerRestoreHook { + + /** + * Gets the identifier of this hook. The identifier is used to identify a specific hook in the + * presence of multiple hooks and to give it the correct checkpointed data upon checkpoint restoration. + * + *

The identifier should be unique between different hooks of a job, but deterministic/constant + * so that upon resuming a savepoint, the hook will get the correct data. + * For example, if the hook calls into another storage system and persists namespace/schema specific + * information, then the name of the storage system, together with the namespace/schema name could + * be an appropriate identifier. + * + *

When multiple hooks of the same name are created and attached to a job graph, only the first + * one is actually used. This can be exploited to deduplicate hooks that would do the same thing. + * + * @return The identifier of the hook. + */ + String getIdentifier(); + + /** + * This method is called by the checkpoint coordinator when triggering a checkpoint, prior + * to sending the "trigger checkpoint" messages to the source tasks. + * + *

If the hook implementation wants to store data as part of the checkpoint, it may return + * that data via a future, otherwise it should return null. The data is stored as part of + * the checkpoint metadata under the hooks identifier (see {@link #getIdentifier()}). + * + *

If the action by this hook needs to be executed synchronously, then this method should + * directly execute the action synchronously and block until it is complete. The returned future + * (if any) would typically be a completed future. + * + *

If the action should be executed asynchronously and only needs to complete before the + * checkpoint is considered completed, then the method may use the given executor to execute the + * actual action and would signal its completion by completing the future. For hooks that do not + * need to store data, the future would be completed with null. + * + * @param checkpointId The ID (logical timestamp, monotonously increasing) of the checkpoint + * @param timestamp The wall clock timestamp when the checkpoint was triggered, for + * info/logging purposes. + * @param executor The executor for asynchronous actions + * + * @return Optionally, a future that signals when the hook has completed and that contains + * data to be stored with the checkpoint. + * + * @throws Exception Exceptions encountered when calling the hook will cause the checkpoint to abort. + */ + @Nullable + Future triggerCheckpoint(long checkpointId, long timestamp, Executor executor) throws Exception; + + /** + * This method is called by the checkpoint coordinator prior to restoring the state of a checkpoint. + * If the checkpoint did store data from this hook, that data will be passed to this method. + * + * @param checkpointId The ID (logical timestamp) of the restored checkpoint + * @param checkpointData The data originally stored in the checkpoint by this hook, possibly null. + * + * @throws Exception Exceptions thrown while restoring the checkpoint will cause the restore + * operation to fail and to possibly fall back to another checkpoint. + */ + void restoreCheckpoint(long checkpointId, @Nullable T checkpointData) throws Exception; + + /** + * Creates the serializer to (de)serialize the data stored by this hook. The serializer + * serializes the result of the Future returned by the {@link #triggerCheckpoint(long, long, Executor)} + * method, and deserializes the data stored in the checkpoint into the object passed to the + * {@link #restoreCheckpoint(long, Object)} method. + * + *

If the hook never returns any data to be stored, then this method may return null as the + * serializer. + * + * @return The serializer to (de)serializes the data stored by this hook + */ + @Nullable + SimpleVersionedSerializer createCheckpointDataSerializer(); + + // ------------------------------------------------------------------------ + // factory + // ------------------------------------------------------------------------ + + /** + * A factory to instantiate a {@code MasterTriggerRestoreHook}. + * + * The hooks are defined when creating the streaming dataflow graph and are attached + * to the job graph, which gets sent to the cluster for execution. To avoid having to make + * the hook implementation serializable, a serializable hook factory is actually attached to the + * job graph instead of the hook implementation itself. + */ + interface Factory extends java.io.Serializable { + + /** + * Instantiates the {@code MasterTriggerRestoreHook}. + */ + MasterTriggerRestoreHook create(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index 900331bb6c5eb..ce97edc676ca1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.savepoint.Savepoint; import org.apache.flink.runtime.checkpoint.savepoint.SavepointStore; -import org.apache.flink.runtime.checkpoint.savepoint.SavepointV1; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -41,8 +41,10 @@ import javax.annotation.Nullable; 
import javax.annotation.concurrent.GuardedBy; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; @@ -89,6 +91,8 @@ public enum TaskAcknowledgeResult { private final Map notYetAcknowledgedTasks; + private final List masterState; + /** Set of acknowledged tasks */ private final Set acknowledgedTasks; @@ -143,6 +147,7 @@ public PendingCheckpoint( this.executor = Preconditions.checkNotNull(executor); this.taskStates = new HashMap<>(); + this.masterState = new ArrayList<>(); this.acknowledgedTasks = new HashSet<>(verticesToConfirm.size()); this.onCompletionPromise = new FlinkCompletableFuture<>(); } @@ -256,7 +261,7 @@ public CompletedCheckpoint finalizeCheckpointExternalized() throws IOException { // make sure we fulfill the promise with an exception if something fails try { // externalize the metadata - final Savepoint savepoint = new SavepointV1(checkpointId, taskStates.values()); + final Savepoint savepoint = new SavepointV2(checkpointId, taskStates.values()); // TEMP FIX - The savepoint store is strictly typed to file systems currently // but the checkpoints think more generic. we need to work with file handles @@ -321,7 +326,8 @@ private CompletedCheckpoint finalizeInternal( checkpointId, checkpointTimestamp, System.currentTimeMillis(), - new HashMap<>(taskStates), + taskStates, + masterState, props, externalMetadata, externalPointer); @@ -344,6 +350,17 @@ private CompletedCheckpoint finalizeInternal( return completed; } + /** + * Adds a master state (state generated on the checkpoint coordinator) to + * the pending checkpoint. + * + * @param state The state to add + */ + public void addMasterState(MasterState state) { + checkNotNull(state); + masterState.add(state); + } + /** * Acknowledges the task with the given execution attempt id and the given subtask state. 
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/hooks/MasterHooks.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/hooks/MasterHooks.java new file mode 100644 index 0000000000000..409019e8327ca --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/hooks/MasterHooks.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.hooks; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.runtime.concurrent.Future; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkException; + +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeoutException; + +/** + * Collection of methods to deal with checkpoint master hooks. + */ +public class MasterHooks { + + // ------------------------------------------------------------------------ + // checkpoint triggering + // ------------------------------------------------------------------------ + + /** + * Triggers all given master hooks and returns state objects for each hook that + * produced a state. + * + * @param hooks The hooks to trigger + * @param checkpointId The checkpoint ID of the triggering checkpoint + * @param timestamp The (informational) timestamp for the triggering checkpoint + * @param executor An executor that can be used for asynchronous I/O calls + * @param timeout The maximum time that a hook may take to complete + * + * @return A list containing all states produced by the hooks + * + * @throws FlinkException Thrown, if the hooks throw an exception, or the state+ + * deserialization fails. 
+ */ + public static List triggerMasterHooks( + Collection> hooks, + long checkpointId, + long timestamp, + Executor executor, + Time timeout) throws FlinkException { + + final ArrayList states = new ArrayList<>(hooks.size()); + + for (MasterTriggerRestoreHook hook : hooks) { + MasterState state = triggerHook(hook, checkpointId, timestamp, executor, timeout); + if (state != null) { + states.add(state); + } + } + + states.trimToSize(); + return states; + } + + private static MasterState triggerHook( + MasterTriggerRestoreHook hook, + long checkpointId, + long timestamp, + Executor executor, + Time timeout) throws FlinkException { + + @SuppressWarnings("unchecked") + final MasterTriggerRestoreHook typedHook = (MasterTriggerRestoreHook) hook; + + final String id = typedHook.getIdentifier(); + final SimpleVersionedSerializer serializer = typedHook.createCheckpointDataSerializer(); + + // call the hook! + final Future resultFuture; + try { + resultFuture = typedHook.triggerCheckpoint(checkpointId, timestamp, executor); + } + catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + throw new FlinkException("Error while triggering checkpoint master hook '" + id + '\'', t); + } + + // if there is a result future, wait for its completion + // in the future we want to make this asynchronous with futures (no pun intended) + if (resultFuture == null) { + return null; + } + else { + final T result; + try { + result = resultFuture.get(timeout.getSize(), timeout.getUnit()); + } + catch (InterruptedException e) { + // cannot continue here - restore interrupt status and leave + Thread.currentThread().interrupt(); + throw new FlinkException("Checkpoint master hook was interrupted"); + } + catch (ExecutionException e) { + throw new FlinkException("Checkpoint master hook '" + id + "' produced an exception", e.getCause()); + } + catch (TimeoutException e) { + throw new FlinkException("Checkpoint master hook '" + id + + "' did not complete in time (" + timeout + ')'); + } + 
+ // if the result of the future is not null, return it as state + if (result == null) { + return null; + } + else if (serializer != null) { + try { + final int version = serializer.getVersion(); + final byte[] bytes = serializer.serialize(result); + + return new MasterState(id, bytes, version); + } + catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + throw new FlinkException("Failed to serialize state of master hook '" + id + '\'', t); + } + } + else { + throw new FlinkException("Checkpoint hook '" + id + " is stateful but creates no serializer"); + } + } + } + + // ------------------------------------------------------------------------ + // checkpoint restoring + // ------------------------------------------------------------------------ + + /** + * Calls the restore method given checkpoint master hooks and passes the given master + * state to them where state with a matching name is found. + * + *

If state is found and no hook with the same name is found, the method throws an + * exception, unless the {@code allowUnmatchedState} flag is set. + * + * @param masterHooks The hooks to call restore on + * @param states The state to pass to the hooks + * @param checkpointId The checkpoint ID of the restored checkpoint + * @param allowUnmatchedState True, if state without a matching hook should be dropped instead of causing an exception + * @param log The logger for log messages + * + * @throws FlinkException Thrown, if the hooks throw an exception, or the state + * deserialization fails. + */ + public static void restoreMasterHooks( + final Map> masterHooks, + final Collection states, + final long checkpointId, + final boolean allowUnmatchedState, + final Logger log) throws FlinkException { + + // early out + if (states == null || states.isEmpty() || masterHooks == null || masterHooks.isEmpty()) { + log.info("No master state to restore"); + return; + } + + log.info("Calling master restore hooks"); + + // collect the hooks + final LinkedHashMap> allHooks = new LinkedHashMap<>(masterHooks); + + // first, deserialize all hook state + final ArrayList, Object>> hooksAndStates = new ArrayList<>(); + + for (MasterState state : states) { + if (state != null) { + final String name = state.name(); + final MasterTriggerRestoreHook hook = allHooks.remove(name); + + if (hook != null) { + log.debug("Found state to restore for hook '{}'", name); + + Object deserializedState = deserializeState(state, hook); + hooksAndStates.add(new Tuple2, Object>(hook, deserializedState)); + } + else if (!allowUnmatchedState) { + throw new IllegalStateException("Found state '" + state.name() + + "' which is not resumed by any hook."); + } + else { + log.info("Dropping unmatched state from '{}'", name); + } + } + } + + // now that all is deserialized, call the hooks + for (Tuple2, Object> hookAndState : hooksAndStates) { + restoreHook(hookAndState.f1, hookAndState.f0, checkpointId); + } + + // trigger the remaining hooks without checkpointed state + for 
(MasterTriggerRestoreHook hook : allHooks.values()) { + restoreHook(null, hook, checkpointId); + } + } + + private static T deserializeState(MasterState state, MasterTriggerRestoreHook hook) throws FlinkException { + @SuppressWarnings("unchecked") + final MasterTriggerRestoreHook typedHook = (MasterTriggerRestoreHook) hook; + final String id = hook.getIdentifier(); + + try { + final SimpleVersionedSerializer deserializer = typedHook.createCheckpointDataSerializer(); + if (deserializer == null) { + throw new FlinkException("null serializer for state of hook " + hook.getIdentifier()); + } + + return deserializer.deserialize(state.version(), state.bytes()); + } + catch (Throwable t) { + throw new FlinkException("Cannot deserialize state for master hook '" + id + '\'', t); + } + } + + private static void restoreHook( + final Object state, + final MasterTriggerRestoreHook hook, + final long checkpointId) throws FlinkException { + + @SuppressWarnings("unchecked") + final T typedState = (T) state; + + @SuppressWarnings("unchecked") + final MasterTriggerRestoreHook typedHook = (MasterTriggerRestoreHook) hook; + + try { + typedHook.restoreCheckpoint(checkpointId, typedState); + } + catch (FlinkException e) { + throw e; + } + catch (Throwable t) { + // catch all here, including Errors that may come from dependency and classpath issues + ExceptionUtils.rethrowIfFatalError(t); + throw new FlinkException("Error while calling restoreCheckpoint on checkpoint hook '" + + hook.getIdentifier() + '\'', t); + } + } + + // ------------------------------------------------------------------------ + + /** This class is not meant to be instantiated */ + private MasterHooks() {} +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java index baad05f79b1e5..79ec59630bf73 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java @@ -20,6 +20,7 @@ import org.apache.flink.core.io.Versioned; import org.apache.flink.runtime.checkpoint.CheckpointIDCounter; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.TaskState; import java.util.Collection; @@ -57,6 +58,11 @@ public interface Savepoint extends Versioned { */ Collection getTaskStates(); + /** + * Gets the checkpointed states generated by the master. + */ + Collection getMasterStates(); + /** * Disposes the savepoint. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java index 60f02878532b3..8ee38daaafd2f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.StreamStateHandle; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -122,7 +123,7 @@ public static CompletedCheckpoint loadAndValidateSavepoint( // (3) convert to checkpoint so the system can fall back to it CheckpointProperties props = CheckpointProperties.forStandardSavepoint(); return new CompletedCheckpoint(jobId, savepoint.getCheckpointId(), 0L, 0L, - taskStates, props, metadataHandle, savepointPath); + taskStates, savepoint.getMasterStates(), props, metadataHandle, savepointPath); } // ------------------------------------------------------------------------ diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java index 3155d609c8429..c1fcf4f0d4547 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointSerializers.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; +import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0; import org.apache.flink.migration.runtime.checkpoint.savepoint.SavepointV0Serializer; import org.apache.flink.util.Preconditions; @@ -30,14 +31,16 @@ public class SavepointSerializers { - private static final int SAVEPOINT_VERSION_0 = 0; private static final Map> SERIALIZERS = new HashMap<>(2); static { - SERIALIZERS.put(SAVEPOINT_VERSION_0, SavepointV0Serializer.INSTANCE); + SERIALIZERS.put(SavepointV0.VERSION, SavepointV0Serializer.INSTANCE); SERIALIZERS.put(SavepointV1.VERSION, SavepointV1Serializer.INSTANCE); + SERIALIZERS.put(SavepointV2.VERSION, SavepointV2Serializer.INSTANCE); } + // ------------------------------------------------------------------------ + /** * Returns the {@link SavepointSerializer} for the given savepoint. 
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java index 5976bbfe782ba..196c8704d5a1a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.util.Preconditions; @@ -60,36 +61,21 @@ public Collection getTaskStates() { } @Override - public void dispose() throws Exception { - for (TaskState taskState : taskStates) { - taskState.discardState(); - } - taskStates.clear(); - } - - @Override - public String toString() { - return "Savepoint(version=" + VERSION + ")"; + public Collection getMasterStates() { + // since checkpoints are never deserialized into this format, + // this method should never be called + throw new UnsupportedOperationException(); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - - if (o == null || getClass() != o.getClass()) { - return false; - } - - SavepointV1 that = (SavepointV1) o; - return checkpointId == that.checkpointId && getTaskStates().equals(that.getTaskStates()); + public void dispose() throws Exception { + // since checkpoints are never deserialized into this format, + // this method should never be called + throw new UnsupportedOperationException(); } @Override - public int hashCode() { - int result = (int) (checkpointId ^ (checkpointId >>> 32)); - result = 31 * result + taskStates.hashCode(); - return result; + public String toString() { + return "Savepoint(version=" + VERSION + ")"; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java index 44461d8b8e5a8..ae9f4a9b1ea2d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -37,18 +38,19 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** - * Serializer for {@link SavepointV1} instances. - *

- *

In contrast to previous savepoint versions, this serializer makes sure - * that no default Java serialization is used for serialization. Therefore, we - * don't rely on any involved Java classes to stay the same. + * Deserializer for checkpoints written in format {@code 1} (Flink 1.2.x format) + * + *

In contrast to the previous versions, this serializer makes sure that no Java + * serialization is used for serialization. Therefore, we don't rely on any involved + * classes to stay the same. */ -class SavepointV1Serializer implements SavepointSerializer { +class SavepointV1Serializer implements SavepointSerializer { private static final byte NULL_HANDLE = 0; private static final byte BYTE_STREAM_STATE_HANDLE = 1; @@ -63,39 +65,12 @@ private SavepointV1Serializer() { } @Override - public void serialize(SavepointV1 savepoint, DataOutputStream dos) throws IOException { - try { - dos.writeLong(savepoint.getCheckpointId()); - - Collection taskStates = savepoint.getTaskStates(); - dos.writeInt(taskStates.size()); - - for (TaskState taskState : savepoint.getTaskStates()) { - // Vertex ID - dos.writeLong(taskState.getJobVertexID().getLowerPart()); - dos.writeLong(taskState.getJobVertexID().getUpperPart()); - - // Parallelism - int parallelism = taskState.getParallelism(); - dos.writeInt(parallelism); - dos.writeInt(taskState.getMaxParallelism()); - dos.writeInt(taskState.getChainLength()); - - // Sub task states - Map subtaskStateMap = taskState.getSubtaskStates(); - dos.writeInt(subtaskStateMap.size()); - for (Map.Entry entry : subtaskStateMap.entrySet()) { - dos.writeInt(entry.getKey()); - serializeSubtaskState(entry.getValue(), dos); - } - } - } catch (Exception e) { - throw new IOException(e); - } + public void serialize(SavepointV2 savepoint, DataOutputStream dos) throws IOException { + throw new UnsupportedOperationException("This serializer is read-only and only exists for backwards compatibility"); } @Override - public SavepointV1 deserialize(DataInputStream dis, ClassLoader cl) throws IOException { + public SavepointV2 deserialize(DataInputStream dis, ClassLoader cl) throws IOException { long checkpointId = dis.readLong(); // Task states @@ -122,7 +97,34 @@ public SavepointV1 deserialize(DataInputStream dis, ClassLoader cl) throws IOExc } } - return new 
SavepointV1(checkpointId, taskStates); + return new SavepointV2(checkpointId, taskStates, Collections.emptyList()); + } + + public void serializeOld(SavepointV1 savepoint, DataOutputStream dos) throws IOException { + dos.writeLong(savepoint.getCheckpointId()); + + Collection taskStates = savepoint.getTaskStates(); + dos.writeInt(taskStates.size()); + + for (TaskState taskState : savepoint.getTaskStates()) { + // Vertex ID + dos.writeLong(taskState.getJobVertexID().getLowerPart()); + dos.writeLong(taskState.getJobVertexID().getUpperPart()); + + // Parallelism + int parallelism = taskState.getParallelism(); + dos.writeInt(parallelism); + dos.writeInt(taskState.getMaxParallelism()); + dos.writeInt(taskState.getChainLength()); + + // Sub task states + Map subtaskStateMap = taskState.getSubtaskStates(); + dos.writeInt(subtaskStateMap.size()); + for (Map.Entry entry : subtaskStateMap.entrySet()) { + dos.writeInt(entry.getKey()); + serializeSubtaskState(entry.getValue(), dos); + } + } } private static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java new file mode 100644 index 0000000000000..100982d80a3c8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.TaskState; + +import java.util.Collection; +import java.util.Collections; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * The persistent checkpoint metadata, format version 2. + * This format was introduced with Flink 1.3.0. + */ +public class SavepointV2 implements Savepoint { + + /** The savepoint version. */ + public static final int VERSION = 2; + + /** The checkpoint ID */ + private final long checkpointId; + + /** The task states */ + private final Collection taskStates; + + /** The states generated by the CheckpointCoordinator */ + private final Collection masterStates; + + + public SavepointV2(long checkpointId, Collection taskStates) { + this(checkpointId, taskStates, Collections.emptyList()); + } + + public SavepointV2(long checkpointId, Collection taskStates, Collection masterStates) { + this.checkpointId = checkpointId; + this.taskStates = checkNotNull(taskStates, "taskStates"); + this.masterStates = checkNotNull(masterStates, "masterStates"); + } + + @Override + public int getVersion() { + return VERSION; + } + + @Override + public long getCheckpointId() { + return checkpointId; + } + + @Override + public Collection getTaskStates() { + return taskStates; + } + + @Override + public Collection getMasterStates() { + return masterStates; + } + + @Override + public void dispose() throws Exception { + for (TaskState taskState : taskStates) { + 
taskState.discardState(); + } + taskStates.clear(); + masterStates.clear(); + } + + @Override + public String toString() { + return "Checkpoint Metadata (version=" + VERSION + ')'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java new file mode 100644 index 0000000000000..307ea1641fece --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.filesystem.FileStateHandle; +import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * (De)serializer for checkpoint metadata format version 2. + * + *

This format version adds the master states (state generated on the checkpoint coordinator) to the checkpoint metadata. + * + *

Basic checkpoint metadata layout: + *

+ *  +--------------+---------------+-----------------+
+ *  | checkpointID | master states | operator states |
+ *  +--------------+---------------+-----------------+
+ *  
+ *  Master state:
+ *  +--------------+---------------------+---------+------+---------------+
+ *  | magic number | num remaining bytes | version | name | payload bytes |
+ *  +--------------+---------------------+---------+------+---------------+
+ * 
+ */ +class SavepointV2Serializer implements SavepointSerializer { + + /** Random magic number for consistency checks */ + private static final int MASTER_STATE_MAGIC_NUMBER = 0xc96b1696; + + private static final byte NULL_HANDLE = 0; + private static final byte BYTE_STREAM_STATE_HANDLE = 1; + private static final byte FILE_STREAM_STATE_HANDLE = 2; + private static final byte KEY_GROUPS_HANDLE = 3; + private static final byte PARTITIONABLE_OPERATOR_STATE_HANDLE = 4; + + /** The singleton instance of the serializer */ + public static final SavepointV2Serializer INSTANCE = new SavepointV2Serializer(); + + // ------------------------------------------------------------------------ + + /** Singleton, not meant to be instantiated */ + private SavepointV2Serializer() {} + + // ------------------------------------------------------------------------ + // (De)serialization entry points + // ------------------------------------------------------------------------ + + @Override + public void serialize(SavepointV2 checkpointMetadata, DataOutputStream dos) throws IOException { + // first: checkpoint ID + dos.writeLong(checkpointMetadata.getCheckpointId()); + + // second: master state + final Collection masterStates = checkpointMetadata.getMasterStates(); + dos.writeInt(masterStates.size()); + for (MasterState ms : masterStates) { + serializeMasterState(ms, dos); + } + + // third: task states + final Collection taskStates = checkpointMetadata.getTaskStates(); + dos.writeInt(taskStates.size()); + + for (TaskState taskState : checkpointMetadata.getTaskStates()) { + // Vertex ID + dos.writeLong(taskState.getJobVertexID().getLowerPart()); + dos.writeLong(taskState.getJobVertexID().getUpperPart()); + + // Parallelism + int parallelism = taskState.getParallelism(); + dos.writeInt(parallelism); + dos.writeInt(taskState.getMaxParallelism()); + dos.writeInt(taskState.getChainLength()); + + // Sub task states + Map subtaskStateMap = taskState.getSubtaskStates(); + 
dos.writeInt(subtaskStateMap.size()); + for (Map.Entry entry : subtaskStateMap.entrySet()) { + dos.writeInt(entry.getKey()); + serializeSubtaskState(entry.getValue(), dos); + } + } + } + + @Override + public SavepointV2 deserialize(DataInputStream dis, ClassLoader cl) throws IOException { + // first: checkpoint ID + final long checkpointId = dis.readLong(); + if (checkpointId < 0) { + throw new IOException("invalid checkpoint ID: " + checkpointId); + } + + // second: master state + final List masterStates; + final int numMasterStates = dis.readInt(); + + if (numMasterStates == 0) { + masterStates = Collections.emptyList(); + } + else if (numMasterStates > 0) { + masterStates = new ArrayList<>(numMasterStates); + for (int i = 0; i < numMasterStates; i++) { + masterStates.add(deserializeMasterState(dis)); + } + } + else { + throw new IOException("invalid number of master states: " + numMasterStates); + } + + // third: task states + final int numTaskStates = dis.readInt(); + final ArrayList taskStates = new ArrayList<>(numTaskStates); + + for (int i = 0; i < numTaskStates; i++) { + JobVertexID jobVertexId = new JobVertexID(dis.readLong(), dis.readLong()); + int parallelism = dis.readInt(); + int maxParallelism = dis.readInt(); + int chainLength = dis.readInt(); + + // Add task state + TaskState taskState = new TaskState(jobVertexId, parallelism, maxParallelism, chainLength); + taskStates.add(taskState); + + // Sub task states + int numSubTaskStates = dis.readInt(); + + for (int j = 0; j < numSubTaskStates; j++) { + int subtaskIndex = dis.readInt(); + SubtaskState subtaskState = deserializeSubtaskState(dis); + taskState.putState(subtaskIndex, subtaskState); + } + } + + return new SavepointV2(checkpointId, taskStates, masterStates); + } + + // ------------------------------------------------------------------------ + // master state (de)serialization methods + // ------------------------------------------------------------------------ + + private void 
serializeMasterState(MasterState state, DataOutputStream dos) throws IOException { + // magic number for error detection + dos.writeInt(MASTER_STATE_MAGIC_NUMBER); + + // for safety, we serialize first into an array and then write the array and its + // length into the checkpoint + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream out = new DataOutputStream(baos); + + out.writeInt(state.version()); + out.writeUTF(state.name()); + + final byte[] bytes = state.bytes(); + out.writeInt(bytes.length); + out.write(bytes, 0, bytes.length); + + out.close(); + byte[] data = baos.toByteArray(); + + dos.writeInt(data.length); + dos.write(data, 0, data.length); + } + + private MasterState deserializeMasterState(DataInputStream dis) throws IOException { + final int magicNumber = dis.readInt(); + if (magicNumber != MASTER_STATE_MAGIC_NUMBER) { + throw new IOException("incorrect magic number in master styte byte sequence"); + } + + final int numBytes = dis.readInt(); + if (numBytes <= 0) { + throw new IOException("found zero or negative length for master state bytes"); + } + + final byte[] data = new byte[numBytes]; + dis.readFully(data); + + final DataInputStream in = new DataInputStream(new ByteArrayInputStream(data)); + + final int version = in.readInt(); + final String name = in.readUTF(); + + final byte[] bytes = new byte[in.readInt()]; + in.readFully(bytes); + + // check that the data is not corrupt + if (in.read() != -1) { + throw new IOException("found trailing bytes in master state"); + } + + return new MasterState(name, bytes, version); + } + + // ------------------------------------------------------------------------ + // task state (de)serialization methods + // ------------------------------------------------------------------------ + + private static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException { + + dos.writeLong(-1); + + ChainedStateHandle nonPartitionableState = 
subtaskState.getLegacyOperatorState(); + + int len = nonPartitionableState != null ? nonPartitionableState.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + StreamStateHandle stateHandle = nonPartitionableState.get(i); + serializeStreamStateHandle(stateHandle, dos); + } + + ChainedStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); + + len = operatorStateBackend != null ? operatorStateBackend.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle stateHandle = operatorStateBackend.get(i); + serializeOperatorStateHandle(stateHandle, dos); + } + + ChainedStateHandle operatorStateFromStream = subtaskState.getRawOperatorState(); + + len = operatorStateFromStream != null ? operatorStateFromStream.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle stateHandle = operatorStateFromStream.get(i); + serializeOperatorStateHandle(stateHandle, dos); + } + + KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); + serializeKeyedStateHandle(keyedStateBackend, dos); + + KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState(); + serializeKeyedStateHandle(keyedStateStream, dos); + } + + private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException { + // Duration field has been removed from SubtaskState + long ignoredDuration = dis.readLong(); + + int len = dis.readInt(); + List nonPartitionableState = new ArrayList<>(len); + for (int i = 0; i < len; ++i) { + StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis); + nonPartitionableState.add(streamStateHandle); + } + + + len = dis.readInt(); + List operatorStateBackend = new ArrayList<>(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); + operatorStateBackend.add(streamStateHandle); + } + + len = dis.readInt(); + List operatorStateStream = new ArrayList<>(len); + 
for (int i = 0; i < len; ++i) { + OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); + operatorStateStream.add(streamStateHandle); + } + + KeyedStateHandle keyedStateBackend = deserializeKeyedStateHandle(dis); + + KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis); + + ChainedStateHandle nonPartitionableStateChain = + new ChainedStateHandle<>(nonPartitionableState); + + ChainedStateHandle operatorStateBackendChain = + new ChainedStateHandle<>(operatorStateBackend); + + ChainedStateHandle operatorStateStreamChain = + new ChainedStateHandle<>(operatorStateStream); + + return new SubtaskState( + nonPartitionableStateChain, + operatorStateBackendChain, + operatorStateStreamChain, + keyedStateBackend, + keyedStateStream); + } + + private static void serializeKeyedStateHandle( + KeyedStateHandle stateHandle, DataOutputStream dos) throws IOException { + + if (stateHandle == null) { + dos.writeByte(NULL_HANDLE); + } else if (stateHandle instanceof KeyGroupsStateHandle) { + KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) stateHandle; + + dos.writeByte(KEY_GROUPS_HANDLE); + dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getStartKeyGroup()); + dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups()); + for (int keyGroup : keyGroupsStateHandle.getKeyGroupRange()) { + dos.writeLong(keyGroupsStateHandle.getOffsetForKeyGroup(keyGroup)); + } + serializeStreamStateHandle(keyGroupsStateHandle.getDelegateStateHandle(), dos); + } else { + throw new IllegalStateException("Unknown KeyedStateHandle type: " + stateHandle.getClass()); + } + } + + private static KeyedStateHandle deserializeKeyedStateHandle(DataInputStream dis) throws IOException { + final int type = dis.readByte(); + if (NULL_HANDLE == type) { + return null; + } else if (KEY_GROUPS_HANDLE == type) { + int startKeyGroup = dis.readInt(); + int numKeyGroups = dis.readInt(); + KeyGroupRange keyGroupRange = KeyGroupRange.of(startKeyGroup, 
startKeyGroup + numKeyGroups - 1); + long[] offsets = new long[numKeyGroups]; + for (int i = 0; i < numKeyGroups; ++i) { + offsets[i] = dis.readLong(); + } + KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets( + keyGroupRange, offsets); + StreamStateHandle stateHandle = deserializeStreamStateHandle(dis); + return new KeyGroupsStateHandle(keyGroupRangeOffsets, stateHandle); + } else { + throw new IllegalStateException("Reading invalid KeyedStateHandle, type: " + type); + } + } + + private static void serializeOperatorStateHandle( + OperatorStateHandle stateHandle, DataOutputStream dos) throws IOException { + + if (stateHandle != null) { + dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE); + Map partitionOffsetsMap = + stateHandle.getStateNameToPartitionOffsets(); + dos.writeInt(partitionOffsetsMap.size()); + for (Map.Entry entry : partitionOffsetsMap.entrySet()) { + dos.writeUTF(entry.getKey()); + + OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue(); + + int mode = stateMetaInfo.getDistributionMode().ordinal(); + dos.writeByte(mode); + + long[] offsets = stateMetaInfo.getOffsets(); + dos.writeInt(offsets.length); + for (long offset : offsets) { + dos.writeLong(offset); + } + } + serializeStreamStateHandle(stateHandle.getDelegateStateHandle(), dos); + } else { + dos.writeByte(NULL_HANDLE); + } + } + + private static OperatorStateHandle deserializeOperatorStateHandle( + DataInputStream dis) throws IOException { + + final int type = dis.readByte(); + if (NULL_HANDLE == type) { + return null; + } else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) { + int mapSize = dis.readInt(); + Map offsetsMap = new HashMap<>(mapSize); + for (int i = 0; i < mapSize; ++i) { + String key = dis.readUTF(); + + int modeOrdinal = dis.readByte(); + OperatorStateHandle.Mode mode = OperatorStateHandle.Mode.values()[modeOrdinal]; + + long[] offsets = new long[dis.readInt()]; + for (int j = 0; j < offsets.length; ++j) { + offsets[j] = dis.readLong(); + } + + 
OperatorStateHandle.StateMetaInfo metaInfo = + new OperatorStateHandle.StateMetaInfo(offsets, mode); + offsetsMap.put(key, metaInfo); + } + StreamStateHandle stateHandle = deserializeStreamStateHandle(dis); + return new OperatorStateHandle(offsetsMap, stateHandle); + } else { + throw new IllegalStateException("Reading invalid OperatorStateHandle, type: " + type); + } + } + + private static void serializeStreamStateHandle( + StreamStateHandle stateHandle, DataOutputStream dos) throws IOException { + + if (stateHandle == null) { + dos.writeByte(NULL_HANDLE); + + } else if (stateHandle instanceof FileStateHandle) { + dos.writeByte(FILE_STREAM_STATE_HANDLE); + FileStateHandle fileStateHandle = (FileStateHandle) stateHandle; + dos.writeLong(stateHandle.getStateSize()); + dos.writeUTF(fileStateHandle.getFilePath().toString()); + + } else if (stateHandle instanceof ByteStreamStateHandle) { + dos.writeByte(BYTE_STREAM_STATE_HANDLE); + ByteStreamStateHandle byteStreamStateHandle = (ByteStreamStateHandle) stateHandle; + dos.writeUTF(byteStreamStateHandle.getHandleName()); + byte[] internalData = byteStreamStateHandle.getData(); + dos.writeInt(internalData.length); + dos.write(byteStreamStateHandle.getData()); + + } else { + throw new IOException("Unknown implementation of StreamStateHandle: " + stateHandle.getClass()); + } + + dos.flush(); + } + + private static StreamStateHandle deserializeStreamStateHandle(DataInputStream dis) throws IOException { + final int type = dis.read(); + if (NULL_HANDLE == type) { + return null; + } else if (FILE_STREAM_STATE_HANDLE == type) { + long size = dis.readLong(); + String pathString = dis.readUTF(); + return new FileStateHandle(new Path(pathString), size); + } else if (BYTE_STREAM_STATE_HANDLE == type) { + String handleName = dis.readUTF(); + int numBytes = dis.readInt(); + byte[] data = new byte[numBytes]; + dis.readFully(data); + return new ByteStreamStateHandle(handleName, data); + } else { + throw new IOException("Unknown 
implementation of StreamStateHandle, code: " + type); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 29b980692b7ac..23ed99d26ee72 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -39,6 +39,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointStatsSnapshot; import org.apache.flink.runtime.checkpoint.CheckpointStatsTracker; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.concurrent.BiFunction; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.concurrent.FutureUtils; @@ -360,6 +361,7 @@ public void enableCheckpointing( List verticesToTrigger, List verticesToWaitFor, List verticesToCommitTo, + List> masterHooks, CheckpointIDCounter checkpointIDCounter, CompletedCheckpointStore checkpointStore, String checkpointDir, @@ -395,6 +397,13 @@ public void enableCheckpointing( checkpointDir, ioExecutor); + // register the master hooks on the checkpoint coordinator + for (MasterTriggerRestoreHook hook : masterHooks) { + if (!checkpointCoordinator.addMasterHook(hook)) { + LOG.warn("Trying to register multiple checkpoint hooks with the name: {}", hook.getIdentifier()); + } + } + checkpointCoordinator.setCheckpointStatsTracker(checkpointStatsTracker); // interval of max long value indicates disable periodic checkpoint, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java index a10c62e31882c..b40817ff8adc1 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; import org.apache.flink.runtime.checkpoint.CheckpointStatsTracker; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.client.JobSubmissionException; import org.apache.flink.runtime.executiongraph.metrics.DownTimeGauge; @@ -51,6 +52,7 @@ import javax.annotation.Nullable; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -230,6 +232,21 @@ public static ExecutionGraph buildGraph( } } + // instantiate the user-defined checkpoint hooks + + final MasterTriggerRestoreHook.Factory[] hookFactories = snapshotSettings.getMasterHooks(); + final List> hooks; + + if (hookFactories == null || hookFactories.length == 0) { + hooks = Collections.emptyList(); + } + else { + hooks = new ArrayList<>(hookFactories.length); + for (MasterTriggerRestoreHook.Factory factory : hookFactories) { + hooks.add(factory.create()); + } + } + executionGraph.enableCheckpointing( snapshotSettings.getCheckpointInterval(), snapshotSettings.getCheckpointTimeout(), @@ -239,6 +256,7 @@ public static ExecutionGraph buildGraph( triggerVertices, ackVertices, confirmVertices, + hooks, checkpointIdCounter, completedCheckpoints, externalizedCheckpointsDir, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java index 38130d4c9978a..3dd037ed71505 
100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.jobgraph.tasks; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.StateBackend; @@ -32,21 +33,21 @@ * need to participate. */ public class JobCheckpointingSettings implements java.io.Serializable { - + private static final long serialVersionUID = -2593319571078198180L; - + private final List verticesToTrigger; private final List verticesToAcknowledge; private final List verticesToConfirm; - + private final long checkpointInterval; - + private final long checkpointTimeout; - + private final long minPauseBetweenCheckpoints; - + private final int maxConcurrentCheckpoints; /** Settings for externalized checkpoints. */ @@ -56,6 +57,9 @@ public class JobCheckpointingSettings implements java.io.Serializable { @Nullable private final StateBackend defaultStateBackend; + /** (Factories for) hooks that are executed on the checkpoint coordinator */ + private final MasterTriggerRestoreHook.Factory[] masterHooks; + /** * Flag indicating whether exactly once checkpoint mode has been configured. * If false, at least once mode has been configured. 
This is @@ -77,12 +81,30 @@ public JobCheckpointingSettings( @Nullable StateBackend defaultStateBackend, boolean isExactlyOnce) { + this(verticesToTrigger, verticesToAcknowledge, verticesToConfirm, + checkpointInterval, checkpointTimeout, minPauseBetweenCheckpoints, maxConcurrentCheckpoints, + externalizedCheckpointSettings, defaultStateBackend, null, isExactlyOnce); + } + + public JobCheckpointingSettings( + List verticesToTrigger, + List verticesToAcknowledge, + List verticesToConfirm, + long checkpointInterval, + long checkpointTimeout, + long minPauseBetweenCheckpoints, + int maxConcurrentCheckpoints, + ExternalizedCheckpointSettings externalizedCheckpointSettings, + @Nullable StateBackend defaultStateBackend, + @Nullable MasterTriggerRestoreHook.Factory[] masterHooks, + boolean isExactlyOnce) { + // sanity checks if (checkpointInterval < 1 || checkpointTimeout < 1 || minPauseBetweenCheckpoints < 0 || maxConcurrentCheckpoints < 1) { throw new IllegalArgumentException(); } - + this.verticesToTrigger = requireNonNull(verticesToTrigger); this.verticesToAcknowledge = requireNonNull(verticesToAcknowledge); this.verticesToConfirm = requireNonNull(verticesToConfirm); @@ -93,14 +115,16 @@ public JobCheckpointingSettings( this.externalizedCheckpointSettings = requireNonNull(externalizedCheckpointSettings); this.defaultStateBackend = defaultStateBackend; this.isExactlyOnce = isExactlyOnce; + + this.masterHooks = masterHooks != null ? 
masterHooks : new MasterTriggerRestoreHook.Factory[0]; } - + // -------------------------------------------------------------------------------------------- public List getVerticesToTrigger() { return verticesToTrigger; } - + public List getVerticesToAcknowledge() { return verticesToAcknowledge; } @@ -134,12 +158,16 @@ public StateBackend getDefaultStateBackend() { return defaultStateBackend; } + public MasterTriggerRestoreHook.Factory[] getMasterHooks() { + return masterHooks; + } + public boolean isExactlyOnce() { return isExactlyOnce; } // -------------------------------------------------------------------------------------------- - + @Override public String toString() { return String.format("SnapshotSettings: interval=%d, timeout=%d, pause-between=%d, " + diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java new file mode 100644 index 0000000000000..0ec46065b3606 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.concurrent.Executors; +import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; +import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; + +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; + +import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isNull; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the user-defined hooks that the checkpoint coordinator can call. 
+ */ +public class CheckpointCoordinatorMasterHooksTest { + + // ------------------------------------------------------------------------ + // hook registration + // ------------------------------------------------------------------------ + + /** + * This method tests that hooks with the same identifier are not registered + * multiple times. + */ + @Test + public void testDeduplicateOnRegister() { + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID()); + + MasterTriggerRestoreHook hook1 = mock(MasterTriggerRestoreHook.class); + when(hook1.getIdentifier()).thenReturn("test id"); + + MasterTriggerRestoreHook hook2 = mock(MasterTriggerRestoreHook.class); + when(hook2.getIdentifier()).thenReturn("test id"); + + MasterTriggerRestoreHook hook3 = mock(MasterTriggerRestoreHook.class); + when(hook3.getIdentifier()).thenReturn("anotherId"); + + assertTrue(cc.addMasterHook(hook1)); + assertFalse(cc.addMasterHook(hook2)); + assertTrue(cc.addMasterHook(hook3)); + } + + /** + * Test that validates correct exceptions when supplying hooks with invalid IDs. 
+ */ + @Test + public void testNullOrInvalidId() { + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID()); + + try { + cc.addMasterHook(null); + fail("expected an exception"); + } catch (NullPointerException ignored) {} + + try { + cc.addMasterHook(mock(MasterTriggerRestoreHook.class)); + fail("expected an exception"); + } catch (IllegalArgumentException ignored) {} + + try { + MasterTriggerRestoreHook hook = mock(MasterTriggerRestoreHook.class); + when(hook.getIdentifier()).thenReturn(" "); + + cc.addMasterHook(hook); + fail("expected an exception"); + } catch (IllegalArgumentException ignored) {} + } + + // ------------------------------------------------------------------------ + // trigger / restore behavior + // ------------------------------------------------------------------------ + + @Test + public void testHooksAreCalledOnTrigger() throws Exception { + final String id1 = "id1"; + final String id2 = "id2"; + + final String state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final MasterTriggerRestoreHook statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook1.getIdentifier()).thenReturn(id1); + when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenReturn(FlinkCompletableFuture.completed(state1)); + + final MasterTriggerRestoreHook statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook2.getIdentifier()).thenReturn(id2); + when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer()); + when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenReturn(FlinkCompletableFuture.completed(state2)); + + final MasterTriggerRestoreHook 
statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + // create the checkpoint coordinator + final JobID jid = new JobID(); + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook1); + cc.addMasterHook(statelessHook); + cc.addMasterHook(statefulHook2); + + // trigger a checkpoint + assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false)); + assertEquals(1, cc.getNumberOfPendingCheckpoints()); + + verify(statefulHook1, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + verify(statefulHook2, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + verify(statelessHook, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + + final long checkpointId = cc.getPendingCheckpoints().values().iterator().next().getCheckpointId(); + cc.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, execId, checkpointId)); + assertEquals(0, cc.getNumberOfPendingCheckpoints()); + + assertEquals(1, cc.getNumberOfRetainedSuccessfulCheckpoints()); + final CompletedCheckpoint chk = cc.getCheckpointStore().getLatestCheckpoint(); + + final Collection masterStates = chk.getMasterHookStates(); + assertEquals(2, masterStates.size()); + + for (MasterState ms : masterStates) { + if (ms.name().equals(id1)) { + assertArrayEquals(state1serialized, ms.bytes()); + assertEquals(StringSerializer.VERSION, ms.version()); + } + else if (ms.name().equals(id2)) { + assertArrayEquals(state2serialized, ms.bytes()); + assertEquals(LongSerializer.VERSION, ms.version()); + } + else { + fail("unrecognized state name: " + ms.name()); + } + } + } + + @Test + public void testHooksAreCalledOnRestore() throws Exception { + final String id1 = "id1"; + final String id2 = "id2"; + + final String 
state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final List masterHookStates = Arrays.asList( + new MasterState(id1, state1serialized, StringSerializer.VERSION), + new MasterState(id2, state2serialized, LongSerializer.VERSION)); + + final MasterTriggerRestoreHook statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook1.getIdentifier()).thenReturn(id1); + when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook2.getIdentifier()).thenReturn(id2); + when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer()); + when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + final JobID jid = new JobID(); + final long checkpointId = 13L; + + final CompletedCheckpoint checkpoint = new CompletedCheckpoint( + jid, checkpointId, 123L, 125L, + Collections.emptyMap(), + masterHookStates, + CheckpointProperties.forStandardCheckpoint(), + null, + null); + + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook1); + cc.addMasterHook(statelessHook); + cc.addMasterHook(statefulHook2); + + cc.getCheckpointStore().addCheckpoint(checkpoint); + cc.restoreLatestCheckpointedState( + 
Collections.emptyMap(), + true, + false); + + verify(statefulHook1, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1)); + verify(statefulHook2, times(1)).restoreCheckpoint(eq(checkpointId), eq(state2)); + verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class)); + } + + @Test + public void checkUnMatchedStateOnRestore() throws Exception { + final String id1 = "id1"; + final String id2 = "id2"; + + final String state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final List masterHookStates = Arrays.asList( + new MasterState(id1, state1serialized, StringSerializer.VERSION), + new MasterState(id2, state2serialized, LongSerializer.VERSION)); + + final MasterTriggerRestoreHook statefulHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook.getIdentifier()).thenReturn(id1); + when(statefulHook.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + final JobID jid = new JobID(); + final long checkpointId = 44L; + + final CompletedCheckpoint checkpoint = new CompletedCheckpoint( + jid, checkpointId, 123L, 125L, + Collections.emptyMap(), + masterHookStates, + CheckpointProperties.forStandardCheckpoint(), + null, + null); + + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook); + cc.addMasterHook(statelessHook); + + cc.getCheckpointStore().addCheckpoint(checkpoint); 
+ + // since we have unmatched state, this should fail + try { + cc.restoreLatestCheckpointedState( + Collections.emptyMap(), + true, + false); + fail("exception expected"); + } + catch (IllegalStateException ignored) {} + + // permitting unmatched state should succeed + cc.restoreLatestCheckpointedState( + Collections.emptyMap(), + true, + true); + + verify(statefulHook, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1)); + verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class)); + } + + // ------------------------------------------------------------------------ + // failure scenarios + // ------------------------------------------------------------------------ + + @Test + public void testSerializationFailsOnTrigger() { + } + + @Test + public void testHookCallFailsOnTrigger() { + } + + @Test + public void testDeserializationFailsOnRestore() { + } + + @Test + public void testHookCallFailsOnRestore() { + } + + @Test + public void testTypeIncompatibleWithSerializerOnStore() { + } + + @Test + public void testTypeIncompatibleWithHookOnRestore() { + } + + // ------------------------------------------------------------------------ + // utilities + // ------------------------------------------------------------------------ + + private static CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... 
ackVertices) { + return new CheckpointCoordinator( + jid, + 10000000L, + 600000L, + 0L, + 1, + ExternalizedCheckpointSettings.none(), + new ExecutionVertex[0], + ackVertices, + new ExecutionVertex[0], + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(10), + null, + Executors.directExecutor()); + } + + private static T mockGeneric(Class clazz) { + @SuppressWarnings("unchecked") + Class typedClass = (Class) clazz; + return mock(typedClass); + } + + // ------------------------------------------------------------------------ + + private static final class StringSerializer implements SimpleVersionedSerializer { + + static final int VERSION = 77; + + @Override + public int getVersion() { + return VERSION; + } + + @Override + public byte[] serialize(String checkpointData) throws IOException { + return checkpointData.getBytes(StandardCharsets.UTF_8); + } + + @Override + public String deserialize(int version, byte[] serialized) throws IOException { + if (version != VERSION) { + throw new IOException("version mismatch"); + } + return new String(serialized, StandardCharsets.UTF_8); + } + } + + // ------------------------------------------------------------------------ + + private static final class LongSerializer implements SimpleVersionedSerializer { + + static final int VERSION = 5; + + @Override + public int getVersion() { + return VERSION; + } + + @Override + public byte[] serialize(Long checkpointData) throws IOException { + final byte[] bytes = new byte[8]; + ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).putLong(0, checkpointData); + return bytes; + } + + @Override + public Long deserialize(int version, byte[] serialized) throws IOException { + assertEquals(VERSION, version); + assertEquals(8, serialized.length); + + return ByteBuffer.wrap(serialized).order(ByteOrder.LITTLE_ENDIAN).getLong(0); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 4a36dd282c0a6..fc6e51664306e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -273,7 +273,7 @@ public TestCompletedCheckpoint( Map taskGroupStates, CheckpointProperties props) { - super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates, props); + super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates, null, props); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index 0b759d4056273..652cc767af366 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -45,7 +45,7 @@ public class CompletedCheckpointTest { @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); + public final TemporaryFolder tmpFolder = new TemporaryFolder(); /** * Tests that persistent checkpoints discard their header file. 
@@ -61,7 +61,10 @@ public void testDiscard() throws Exception { // Verify discard call is forwarded to state CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, taskStates, CheckpointProperties.forStandardCheckpoint(), + new JobID(), 0, 0, 1, + taskStates, + Collections.emptyList(), + CheckpointProperties.forStandardCheckpoint(), new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); @@ -81,8 +84,12 @@ public void testCleanUpOnSubsume() throws Exception { boolean discardSubsumed = true; CheckpointProperties props = new CheckpointProperties(false, false, discardSubsumed, true, true, true, true); + CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, taskStates, props); + new JobID(), 0, 0, 1, + taskStates, + Collections.emptyList(), + props); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); checkpoint.registerSharedStates(sharedStateRegistry); @@ -117,7 +124,10 @@ public void testCleanUpOnShutdown() throws Exception { // Keep CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, new HashMap<>(taskStates), props, + new JobID(), 0, 0, 1, + new HashMap<>(taskStates), + Collections.emptyList(), + props, new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); @@ -132,7 +142,10 @@ public void testCleanUpOnShutdown() throws Exception { // Discard props = new CheckpointProperties(false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); + new JobID(), 0, 0, 1, + new HashMap<>(taskStates), + Collections.emptyList(), + props); checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); @@ -155,6 +168,7 @@ public void testCompletedCheckpointStatsCallbacks() throws Exception { 0, 1, new 
HashMap<>(taskStates), + Collections.emptyList(), CheckpointProperties.forStandardCheckpoint()); CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java index 5fce62ebc2528..1f038bd274743 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java @@ -105,6 +105,7 @@ private ExecutionGraph createExecutionGraphAndEnableCheckpointing( Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), + Collections.>emptyList(), counter, store, null, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java similarity index 69% rename from flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index 08ec35e31994c..7d9874e96e2c2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -26,50 +27,41 @@ import 
org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle.StateMetaInfo; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; -import org.junit.Test; +import org.apache.flink.util.StringUtils; -import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; -import java.util.concurrent.ThreadLocalRandom; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -public class SavepointV1Test { +/** + * A collection of utility methods for testing the (de)serialization of + * checkpoint metadata for persistence. + */ +public class CheckpointTestUtils { /** - * Simple test of savepoint methods. + * Creates a random collection of TaskState objects containing various types of state handles. 
*/ - @Test - public void testSavepointV1() throws Exception { - long checkpointId = ThreadLocalRandom.current().nextLong(Integer.MAX_VALUE); - int numTaskStates = 4; - int numSubtaskStates = 16; - - Collection expected = createTaskStates(numTaskStates, numSubtaskStates); - - SavepointV1 savepoint = new SavepointV1(checkpointId, expected); - - assertEquals(SavepointV1.VERSION, savepoint.getVersion()); - assertEquals(checkpointId, savepoint.getCheckpointId()); - assertEquals(expected, savepoint.getTaskStates()); - - assertFalse(savepoint.getTaskStates().isEmpty()); - savepoint.dispose(); - assertTrue(savepoint.getTaskStates().isEmpty()); + public static Collection createTaskStates(int numTaskStates, int numSubtasksPerTask) { + return createTaskStates(new Random(), numTaskStates, numSubtasksPerTask); } - static Collection createTaskStates(int numTaskStates, int numSubtasksPerTask) throws IOException { - - Random random = new Random(numTaskStates * 31 + numSubtasksPerTask); + /** + * Creates a random collection of TaskState objects containing various types of state handles. 
+ */ + public static Collection createTaskStates( + Random random, + int numTaskStates, + int numSubtasksPerTask) { List taskStates = new ArrayList<>(numTaskStates); @@ -96,12 +88,12 @@ static Collection createTaskStates(int numTaskStates, int numSubtasks StreamStateHandle nonPartitionableState = new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes( - ConfigConstants.DEFAULT_CHARSET)); + ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateBackend = new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateStream = new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); - Map offsetsMap = new HashMap<>(); + Map offsetsMap = new HashMap<>(); offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); @@ -130,14 +122,14 @@ static Collection createTaskStates(int numTaskStates, int numSubtasks keyedStateBackend = new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{42}), new TestByteStreamStateHandleDeepCompare("c", "Hello" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); + .getBytes(ConfigConstants.DEFAULT_CHARSET))); } if (hasKeyedStream) { keyedStateStream = new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{23}), new TestByteStreamStateHandleDeepCompare("d", "World" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); + .getBytes(ConfigConstants.DEFAULT_CHARSET))); } taskState.putState(subtaskIdx, new SubtaskState( @@ -154,4 +146,39 @@ static Collection createTaskStates(int numTaskStates, int numSubtasks return 
taskStates; } + /** + * Creates a bunch of random master states. + */ + public static Collection createRandomMasterStates(Random random, int num) { + final ArrayList states = new ArrayList<>(num); + + for (int i = 0; i < num; i++) { + int version = random.nextInt(10); + String name = StringUtils.getRandomString(random, 5, 500); + byte[] bytes = new byte[random.nextInt(5000) + 1]; + random.nextBytes(bytes); + + states.add(new MasterState(name, bytes, version)); + } + + return states; + } + + /** + * Asserts that two MasterStates are equal. + * + *

The MasterState avoids overriding {@code equals()} on purpose, because equality is not well + * defined in the raw contents. + */ + public static void assertMasterStateEquality(MasterState a, MasterState b) { + assertEquals(a.version(), b.version()); + assertEquals(a.name(), b.name()); + assertArrayEquals(a.bytes(), b.bytes()); + + } + + // ------------------------------------------------------------------------ + + /** utility class, not meant to be instantiated */ + private CheckpointTestUtils() {} } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java index c66b29d9a82dd..20b1e5713bfa4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java @@ -67,7 +67,7 @@ public void testLoadAndValidateSavepoint() throws Exception { JobID jobId = new JobID(); // Store savepoint - SavepointV1 savepoint = new SavepointV1(checkpointId, taskStates.values()); + SavepointV2 savepoint = new SavepointV2(checkpointId, taskStates.values()); String path = SavepointStore.storeSavepoint(tmp.getAbsolutePath(), savepoint); ExecutionJobVertex vertex = mock(ExecutionJobVertex.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java index 1eb805599967c..cf79282797479 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java @@ -23,6 +23,7 @@ import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.fs.FileSystem; import 
org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; @@ -68,7 +69,7 @@ public void testStoreLoadDispose() throws Exception { // Store String savepointDirectory = SavepointStore.createSavepointDirectory(root, new JobID()); - SavepointV1 stored = new SavepointV1(1929292, SavepointV1Test.createTaskStates(4, 24)); + SavepointV2 stored = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); String path = SavepointStore.storeSavepoint(savepointDirectory, stored); list = rootFile.listFiles(); @@ -77,7 +78,10 @@ public void testStoreLoadDispose() throws Exception { // Load Savepoint loaded = SavepointStore.loadSavepoint(path, Thread.currentThread().getContextClassLoader()); - assertEquals(stored, loaded); + + assertEquals(stored.getCheckpointId(), loaded.getCheckpointId()); + assertEquals(stored.getTaskStates(), loaded.getTaskStates()); + assertEquals(stored.getMasterStates(), loaded.getMasterStates()); loaded.dispose(); @@ -126,8 +130,8 @@ public void testMultipleSavepointVersions() throws Exception { File rootFile = new File(root); // New savepoint type for test - int version = ThreadLocalRandom.current().nextInt(); - long checkpointId = ThreadLocalRandom.current().nextLong(); + int version = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE); // make this a positive number + long checkpointId = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE); // make this a positive number // Add serializer serializers.put(version, NewSavepointSerializer.INSTANCE); @@ -143,7 +147,7 @@ public void testMultipleSavepointVersions() throws Exception { // Savepoint v0 String savepointDirectory2 = SavepointStore.createSavepointDirectory(root, new JobID()); - Savepoint savepoint = new SavepointV1(checkpointId, SavepointV1Test.createTaskStates(4, 32)); + SavepointV2 savepoint = new 
SavepointV2(checkpointId, CheckpointTestUtils.createTaskStates(4, 32)); String pathSavepoint = SavepointStore.storeSavepoint(savepointDirectory2, savepoint); list = rootFile.listFiles(); @@ -156,7 +160,9 @@ public void testMultipleSavepointVersions() throws Exception { assertEquals(newSavepoint, loaded); loaded = SavepointStore.loadSavepoint(pathSavepoint, Thread.currentThread().getContextClassLoader()); - assertEquals(savepoint, loaded); + assertEquals(savepoint.getCheckpointId(), loaded.getCheckpointId()); + assertEquals(savepoint.getTaskStates(), loaded.getTaskStates()); + assertEquals(savepoint.getMasterStates(), loaded.getMasterStates()); } /** @@ -199,7 +205,7 @@ public void testStoreExternalizedCheckpointsToSameDirectory() throws Exception { FileSystem fs = FileSystem.get(new Path(root).toUri()); // Store - SavepointV1 savepoint = new SavepointV1(1929292, SavepointV1Test.createTaskStates(4, 24)); + SavepointV2 savepoint = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); FileStateHandle store1 = SavepointStore.storeExternalizedCheckpointToHandle(root, savepoint); fs.exists(store1.getFilePath()); @@ -251,7 +257,12 @@ public long getCheckpointId() { @Override public Collection getTaskStates() { - return Collections.EMPTY_LIST; + return Collections.emptyList(); + } + + @Override + public Collection getMasterStates() { + return Collections.emptyList(); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java index 58cf1aa9b79e0..0eff7bc398f8e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java @@ -27,6 +27,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; +import 
static org.junit.Assert.assertTrue; public class SavepointV1SerializerTest { @@ -35,25 +36,29 @@ public class SavepointV1SerializerTest { */ @Test public void testSerializeDeserializeV1() throws Exception { - Random r = new Random(42); - for (int i = 0; i < 100; ++i) { + final Random r = new Random(42); + + for (int i = 0; i < 50; ++i) { SavepointV1 expected = - new SavepointV1(i+ 123123, SavepointV1Test.createTaskStates(1 + r.nextInt(64), 1 + r.nextInt(64))); + new SavepointV1(i+ 123123, CheckpointTestUtils.createTaskStates(r, 1 + r.nextInt(64), 1 + r.nextInt(64))); SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE; // Serialize ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); - serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); + serializer.serializeOld(expected, new DataOutputViewStreamWrapper(baos)); byte[] bytes = baos.toByteArray(); // Deserialize ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - Savepoint actual = serializer.deserialize( + SavepointV2 actual = serializer.deserialize( new DataInputViewStreamWrapper(bais), Thread.currentThread().getContextClassLoader()); - assertEquals(expected, actual); + + assertEquals(expected.getCheckpointId(), actual.getCheckpointId()); + assertEquals(expected.getTaskStates(), actual.getTaskStates()); + assertTrue(actual.getMasterStates().isEmpty()); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java new file mode 100644 index 0000000000000..deb14ddc85a8c --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; +import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.TaskState; + +import org.junit.Test; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +/** + * Various tests for the version 2 format serializer of a checkpoint. 
+ */ +public class SavepointV2SerializerTest { + + @Test + public void testCheckpointWithNoState() throws Exception { + final Random rnd = new Random(); + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + final Collection taskStates = Collections.emptyList(); + final Collection masterStates = Collections.emptyList(); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithOnlyMasterState() throws Exception { + final Random rnd = new Random(); + final int maxNumMasterStates = 5; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final Collection taskStates = Collections.emptyList(); + + final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; + final Collection masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithOnlyTaskState() throws Exception { + final Random rnd = new Random(); + final int maxTaskStates = 20; + final int maxNumSubtasks = 20; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final int numTasks = rnd.nextInt(maxTaskStates) + 1; + final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; + final Collection taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + + final Collection masterStates = Collections.emptyList(); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithMasterAndTaskState() throws Exception { + final Random rnd = new Random(); + + final int maxNumMasterStates = 5; + final int maxTaskStates = 20; + final int maxNumSubtasks = 20; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final int numTasks 
= rnd.nextInt(maxTaskStates) + 1; + final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; + final Collection taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + + final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; + final Collection masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + private void testCheckpointSerialization( + long checkpointId, + Collection taskStates, + Collection masterStates) throws IOException { + + SavepointV2Serializer serializer = SavepointV2Serializer.INSTANCE; + + ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); + DataOutputStream out = new DataOutputViewStreamWrapper(baos); + + serializer.serialize(new SavepointV2(checkpointId, taskStates, masterStates), out); + out.close(); + + byte[] bytes = baos.toByteArray(); + + DataInputStream in = new DataInputViewStreamWrapper(new ByteArrayInputStreamWithPos(bytes)); + SavepointV2 deserialized = serializer.deserialize(in, getClass().getClassLoader()); + + assertEquals(checkpointId, deserialized.getCheckpointId()); + assertEquals(taskStates, deserialized.getTaskStates()); + + assertEquals(masterStates.size(), deserialized.getMasterStates().size()); + for (Iterator a = masterStates.iterator(), b = deserialized.getMasterStates().iterator(); + a.hasNext();) + { + CheckpointTestUtils.assertMasterStateEquality(a.next(), b.next()); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java new file mode 100644 index 0000000000000..428a62a675c7c --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.TaskState; + +import org.junit.Test; + +import java.util.Collection; +import java.util.Random; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class SavepointV2Test { + + /** + * Simple test of savepoint methods. 
+ */ + @Test + public void testSavepointV1() throws Exception { + final Random rnd = new Random(); + + final long checkpointId = rnd.nextInt(Integer.MAX_VALUE) + 1; + final int numTaskStates = 4; + final int numSubtaskStates = 16; + final int numMasterStates = 7; + + Collection taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTaskStates, numSubtaskStates); + + Collection masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + SavepointV2 checkpoint = new SavepointV2(checkpointId, taskStates, masterStates); + + assertEquals(2, checkpoint.getVersion()); + assertEquals(checkpointId, checkpoint.getCheckpointId()); + assertEquals(taskStates, checkpoint.getTaskStates()); + assertEquals(masterStates, checkpoint.getMasterStates()); + + assertFalse(checkpoint.getTaskStates().isEmpty()); + assertFalse(checkpoint.getMasterStates().isEmpty()); + + checkpoint.dispose(); + + assertTrue(checkpoint.getTaskStates().isEmpty()); + assertTrue(checkpoint.getMasterStates().isEmpty()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java index f96b62456dcf8..4e1d0f7175da7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.checkpoint.CheckpointStatsSnapshot; import org.apache.flink.runtime.checkpoint.CheckpointStatsTracker; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore; import org.apache.flink.runtime.execution.ExecutionState; @@ -128,6 +129,7 @@ 
public static void setupExecutionGraph() throws Exception { Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), + Collections.>emptyList(), new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java new file mode 100644 index 0000000000000..b26cf4f1505e8 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.checkpoint; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.FlinkException; + +/** + * Sources that implement this interface do not trigger checkpoints when receiving a + * trigger message from the checkpoint coordinator, but when their input data/events + * indicate that a checkpoint should be triggered. + * + *

Since sources cannot simply create a new checkpoint on their own, this mechanism + * always goes together with a {@link WithMasterCheckpointHook hook on the master side}. + * In a typical setup, the hook on the master tells the source system (for example + * the message queue) to prepare a checkpoint. The exact point when the checkpoint is + * taken is then controlled by the event stream received from the source, and triggered + * by the source function (implementing this interface) in Flink when seeing the relevant + * events. + * + * @param Type of the elements produced by the source function + * @param The type of the data stored in the checkpoint by the master that triggers + */ +@PublicEvolving +public interface ExternallyInducedSource extends SourceFunction, WithMasterCheckpointHook { + + /** + * Sets the checkpoint trigger through which the source can trigger the checkpoint. + * + * @param checkpointTrigger The checkpoint trigger to set + */ + void setCheckpointTrigger(CheckpointTrigger checkpointTrigger); + + // ------------------------------------------------------------------------ + + /** + * Through the {@code CheckpointTrigger}, the source function notifies the Flink + * source operator when to trigger the checkpoint. + */ + interface CheckpointTrigger { + + /** + * Triggers a checkpoint. This method should be called by the source + * when it sees the event that indicates that a checkpoint should be triggered. + * + *

When this method is called, the parallel operator instance in which the + * calling source function runs will perform its checkpoint and insert the + * checkpoint barrier into the data stream. + * + * @param checkpointId The ID that identifies the checkpoint. + * + * @throws FlinkException Thrown when the checkpoint could not be triggered, for example + * because of an invalid state or errors when storing the + * checkpoint state. + */ + void triggerCheckpoint(long checkpointId) throws FlinkException; + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java new file mode 100644 index 0000000000000..ef872de0e34d6 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.checkpoint; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; + +/** + * This interface can be implemented by streaming functions that need to trigger a + * "global action" on the master (in the checkpoint coordinator) as part of every + * checkpoint and restore operation. + * + * @param The type of the data stored by the hook in the checkpoint, or {@code Void}, if none. + */ +@PublicEvolving +public interface WithMasterCheckpointHook extends java.io.Serializable { + + /** + * Creates the hook that should be called by the checkpoint coordinator. + */ + MasterTriggerRestoreHook createMasterTriggerRestoreHook(); +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java new file mode 100644 index 0000000000000..c2566980052c2 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.graph; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; + +/** + * Utility class that turns a {@link WithMasterCheckpointHook} into a + * {@link org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook.Factory}. + */ +class FunctionMasterCheckpointHookFactory implements MasterTriggerRestoreHook.Factory { + + private static final long serialVersionUID = 2L; + + private final WithMasterCheckpointHook creator; + + FunctionMasterCheckpointHookFactory(WithMasterCheckpointHook creator) { + this.creator = checkNotNull(creator); + } + + @SuppressWarnings("unchecked") + @Override + public MasterTriggerRestoreHook create() { + return (MasterTriggerRestoreHook) creator.createMasterTriggerRestoreHook(); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index a1d33d88713ee..7f24cd358e353 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.operators.ResourceSpec; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.restartstrategy.RestartStrategies; @@ -36,6 +37,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.migration.streaming.api.graph.StreamGraphHasherV1; +import 
org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.InputFormatVertex; @@ -51,7 +53,9 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.operators.util.TaskConfig; import org.apache.flink.streaming.api.CheckpointingMode; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; import org.apache.flink.streaming.api.environment.CheckpointConfig; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; @@ -542,6 +546,8 @@ private void configureCheckpointing() { interval = Long.MAX_VALUE; } + // --- configure the participating vertices --- + // collect the vertices that receive "trigger checkpoint" messages. 
// currently, these are all the sources List triggerVertices = new ArrayList<>(); @@ -552,7 +558,7 @@ private void configureCheckpointing() { // collect the vertices that receive "commit checkpoint" messages // currently, these are all vertices - List commitVertices = new ArrayList<>(); + List commitVertices = new ArrayList<>(jobVertices.size()); for (JobVertex vertex : jobVertices.values()) { if (vertex.isInputVertex()) { @@ -562,6 +568,8 @@ private void configureCheckpointing() { ackVertices.add(vertex.getID()); } + // --- configure options --- + ExternalizedCheckpointSettings externalizedCheckpointSettings; if (cfg.isExternalizedCheckpointsEnabled()) { CheckpointConfig.ExternalizedCheckpointCleanup cleanup = cfg.getExternalizedCheckpointCleanup(); @@ -587,12 +595,30 @@ private void configureCheckpointing() { "exactly-once or at-least-once."); } + // --- configure the master-side checkpoint hooks --- + + final ArrayList hooks = new ArrayList<>(); + + for (StreamNode node : streamGraph.getStreamNodes()) { + StreamOperator op = node.getOperator(); + if (op instanceof AbstractUdfStreamOperator) { + Function f = ((AbstractUdfStreamOperator) op).getUserFunction(); + + if (f instanceof WithMasterCheckpointHook) { + hooks.add(new FunctionMasterCheckpointHookFactory((WithMasterCheckpointHook) f)); + } + } + } + + // --- done, put it all together --- + JobCheckpointingSettings settings = new JobCheckpointingSettings( triggerVertices, ackVertices, commitVertices, interval, cfg.getCheckpointTimeout(), cfg.getMinPauseBetweenCheckpoints(), cfg.getMaxConcurrentCheckpoints(), externalizedCheckpointSettings, streamGraph.getStateBackend(), + hooks.toArray(new MasterTriggerRestoreHook.Factory[hooks.size()]), isExactlyOnce); jobGraph.setSnapshotSettings(settings); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java index 
66e92dff70179..31cd7c18d24d5 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java @@ -19,8 +19,12 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.util.FlinkException; /** * {@link StreamTask} for executing a {@link StreamSource}. @@ -40,9 +44,44 @@ public class SourceStreamTask, OP extends StreamSource> extends StreamTask { + private volatile boolean externallyInducedCheckpoints; + @Override protected void init() { // does not hold any resources, so no initialization needed + + // we check if the source is actually inducing the checkpoints, rather + // than the trigger message from the master + SourceFunction source = headOperator.getUserFunction(); + if (source instanceof ExternallyInducedSource) { + externallyInducedCheckpoints = true; + + ExternallyInducedSource.CheckpointTrigger triggerHook = new ExternallyInducedSource.CheckpointTrigger() { + + @Override + public void triggerCheckpoint(long checkpointId) throws FlinkException { + // TODO - we need to see how to derive those. 
We should probably not encode this in the + // TODO - source's trigger message, but do a handshake in this task between the trigger + // TODO - message from the master, and the source's trigger notification + final CheckpointOptions checkpointOptions = CheckpointOptions.forFullCheckpoint(); + final long timestamp = System.currentTimeMillis(); + + final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); + + try { + SourceStreamTask.super.triggerCheckpoint(checkpointMetaData, checkpointOptions); + } + catch (RuntimeException | FlinkException e) { + throw e; + } + catch (Exception e) { + throw new FlinkException(e.getMessage(), e); + } + } + }; + + ((ExternallyInducedSource) source).setCheckpointTrigger(triggerHook); + } } @Override @@ -62,4 +101,21 @@ protected void cancelTask() throws Exception { headOperator.cancel(); } } + + // ------------------------------------------------------------------------ + // Checkpointing + // ------------------------------------------------------------------------ + + @Override + public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { + if (!externallyInducedCheckpoints) { + return super.triggerCheckpoint(checkpointMetaData, checkpointOptions); + } + else { + // we do not trigger checkpoints here, we simply state whether we can trigger them + synchronized (getCheckpointLock()) { + return isRunning(); + } + } + } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java new file mode 100644 index 0000000000000..b5a95eb3c0c18 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.graph; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook.Factory; +import org.apache.flink.runtime.concurrent.Future; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.DiscardingSink; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import org.junit.Test; + +import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.Executor; + +import static java.util.Arrays.asList; +import static org.junit.Assert.*; + +/** + * Tests that when sources implement {@link WithMasterCheckpointHook} the hooks are properly + * configured in the job's checkpoint settings. 
+ */ +@SuppressWarnings("serial") +public class WithMasterCheckpointHookConfigTest { + + /** + * This test creates a program with 4 sources (2 with master hooks, 2 without), each followed by a map function (2 with master hooks, 2 without). + * The resulting job graph must have 4 configured master hooks. + */ + @Test + public void testHookConfiguration() throws Exception { + // create some sources some of which configure master hooks + final TestSource source1 = new TestSource(); + final TestSourceWithHook source2 = new TestSourceWithHook("foo"); + final TestSource source3 = new TestSource(); + final TestSourceWithHook source4 = new TestSourceWithHook("bar"); + + final MapFunction identity = new Identity<>(); + final IdentityWithHook identityWithHook1 = new IdentityWithHook<>("apple"); + final IdentityWithHook identityWithHook2 = new IdentityWithHook<>("orange"); + + final Set> hooks = new HashSet>(asList( + source2.createMasterTriggerRestoreHook(), + source4.createMasterTriggerRestoreHook(), + identityWithHook1.createMasterTriggerRestoreHook(), + identityWithHook2.createMasterTriggerRestoreHook())); + + // we can instantiate a local environment here, because we never actually execute something + final StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + env.enableCheckpointing(500); + + env + .addSource(source1).map(identity) + .union(env.addSource(source2).map(identity)) + .union(env.addSource(source3).map(identityWithHook1)) + .union(env.addSource(source4).map(identityWithHook2)) + .addSink(new DiscardingSink()); + + final JobGraph jg = env.getStreamGraph().getJobGraph(); + assertEquals(hooks.size(), jg.getCheckpointingSettings().getMasterHooks().length); + + // check that all hooks are contained and exist exactly once + for (Factory f : jg.getCheckpointingSettings().getMasterHooks()) { + MasterTriggerRestoreHook hook = f.create(); + assertTrue(hooks.remove(hook)); + } + assertTrue(hooks.isEmpty()); + } + + // ----------------------------------------------------------------------- + + private 
static class TestHook implements MasterTriggerRestoreHook { + + private final String id; + + TestHook(String id) { + this.id = id; + } + + @Override + public String getIdentifier() { + return id; + } + + @Override + public Future triggerCheckpoint(long checkpointId, long timestamp, Executor executor) { + throw new UnsupportedOperationException(); + } + + @Override + public void restoreCheckpoint(long checkpointId, @Nullable String checkpointData) throws Exception { + throw new UnsupportedOperationException(); + } + + @Nullable + @Override + public SimpleVersionedSerializer createCheckpointDataSerializer() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(Object obj) { + return obj == this || (obj != null && obj.getClass() == getClass() && ((TestHook) obj).id.equals(id)); + } + + @Override + public int hashCode() { + return id.hashCode(); + } + } + + // ----------------------------------------------------------------------- + + private static class TestSource implements SourceFunction { + + @Override + public void run(SourceContext ctx) { + throw new UnsupportedOperationException(); + } + + @Override + public void cancel() {} + } + + // ----------------------------------------------------------------------- + + private static class TestSourceWithHook extends TestSource implements WithMasterCheckpointHook { + + private final String id; + + TestSourceWithHook(String id) { + this.id = id; + } + + @Override + public TestHook createMasterTriggerRestoreHook() { + return new TestHook(id); + } + } + + // ----------------------------------------------------------------------- + + private static class Identity implements MapFunction { + + @Override + public T map(T value) { + return value; + } + } + + // ----------------------------------------------------------------------- + + private static class IdentityWithHook extends Identity implements WithMasterCheckpointHook { + + private final String id; + + IdentityWithHook(String id) { + 
this.id = id; + } + + @Override + public TestHook createMasterTriggerRestoreHook() { + return new TestHook(id); + } + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java index 38741baac2022..54cd186773c94 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java @@ -29,12 +29,9 @@ import org.apache.flink.types.LongValue; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import java.io.IOException; @@ -45,8 +42,6 @@ * This test uses the PowerMockRunner runner to work around the fact that the * {@link ResultPartitionWriter} class is final. */ -@RunWith(PowerMockRunner.class) -@PrepareForTest(ResultPartitionWriter.class) public class StreamRecordWriterTest { /** diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java new file mode 100644 index 0000000000000..e5caff35777d8 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.tasks; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.core.testutils.MultiShotLatch; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource; +import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.Test; + +import java.util.concurrent.BlockingQueue; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * These tests verify the behavior of a source function that triggers checkpoints + * in response to received events. 
+ */ +@SuppressWarnings("serial") +public class SourceExternalCheckpointTriggerTest { + + static final OneShotLatch ready = new OneShotLatch(); + static final MultiShotLatch sync = new MultiShotLatch(); + + @Test + public void testCheckpointsTriggeredBySource() throws Exception { + // set up the basic test harness + final SourceStreamTask sourceTask = new SourceStreamTask>(); + final StreamTaskTestHarness testHarness = new StreamTaskTestHarness<>(sourceTask, BasicTypeInfo.LONG_TYPE_INFO); + testHarness.setupOutputForSingletonOperatorChain(); + testHarness.getExecutionConfig().setLatencyTrackingInterval(-1); + + final long numElements = 10; + final long checkpointEvery = 3; + + // set up the source function + ExternalCheckpointsSource source = new ExternalCheckpointsSource(numElements, checkpointEvery); + StreamConfig streamConfig = testHarness.getStreamConfig(); + StreamSource sourceOperator = new StreamSource<>(source); + streamConfig.setStreamOperator(sourceOperator); + + // this starts the source thread + testHarness.invoke(); + ready.await(); + + // now send an external trigger that should be ignored + assertTrue(sourceTask.triggerCheckpoint(new CheckpointMetaData(32, 829), CheckpointOptions.forFullCheckpoint())); + + // step by step let the source thread emit elements + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 1L); + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 2L); + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 3L); + + verifyCheckpointBarrier(testHarness.getOutput(), 1L); + + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 4L); + + // now send a regular trigger command that should be ignored + assertTrue(sourceTask.triggerCheckpoint(new CheckpointMetaData(34, 900), CheckpointOptions.forFullCheckpoint())); + + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 5L); + sync.trigger(); + verifyNextElement(testHarness.getOutput(), 6L); + + verifyCheckpointBarrier(testHarness.getOutput(), 
2L); + + // let the remainder run + + for (long l = 7L, checkpoint = 3L; l <= numElements; l++) { + sync.trigger(); + verifyNextElement(testHarness.getOutput(), l); + + if (l % checkpointEvery == 0) { + verifyCheckpointBarrier(testHarness.getOutput(), checkpoint++); + } + } + + // done! + } + + @SuppressWarnings("unchecked") + private void verifyNextElement(BlockingQueue output, long expectedElement) throws InterruptedException { + Object next = output.take(); + assertTrue("next element is not an event", next instanceof StreamRecord); + assertEquals("wrong event", expectedElement, ((StreamRecord) next).getValue().longValue()); + } + + private void verifyCheckpointBarrier(BlockingQueue output, long checkpointId) throws InterruptedException { + Object next = output.take(); + assertTrue("next element is not a checkpoint barrier", next instanceof CheckpointBarrier); + assertEquals("wrong checkpoint id", checkpointId, ((CheckpointBarrier) next).getId()); + } + + // ------------------------------------------------------------------------ + + private static class ExternalCheckpointsSource + implements ParallelSourceFunction, ExternallyInducedSource { + + private final long numEvents; + private final long checkpointFrequency; + + private CheckpointTrigger trigger; + + ExternalCheckpointsSource(long numEvents, long checkpointFrequency) { + this.numEvents = numEvents; + this.checkpointFrequency = checkpointFrequency; + } + + @Override + public void run(SourceContext ctx) throws Exception { + ready.trigger(); + + // for simplicity in this test, we just trigger checkpoints in ascending order + long checkpoint = 1; + + for (long num = 1; num <= numEvents; num++) { + sync.await(); + ctx.collect(num); + if (num % checkpointFrequency == 0) { + trigger.triggerCheckpoint(checkpoint++); + } + } + } + + @Override + public void cancel() {} + + @Override + public void setCheckpointTrigger(CheckpointTrigger checkpointTrigger) { + this.trigger = checkpointTrigger; + } + + @Override + 
public MasterTriggerRestoreHook createMasterTriggerRestoreHook() { + // not relevant in this test + throw new UnsupportedOperationException("not implemented"); + } + } +} + diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index c51af4e1b7610..0be85b1968b13 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -45,6 +45,7 @@ import java.util.LinkedList; import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.LinkedBlockingQueue; /** * Test harness for testing a {@link StreamTask}. @@ -83,7 +84,7 @@ public class StreamTaskTestHarness { private TypeSerializer outputSerializer; private TypeSerializer outputStreamRecordSerializer; - private ConcurrentLinkedQueue outputList; + private LinkedBlockingQueue outputList; protected TaskThread taskThread; @@ -125,7 +126,7 @@ protected void initializeInputs() throws IOException, InterruptedException {} @SuppressWarnings("unchecked") private void initializeOutput() { - outputList = new ConcurrentLinkedQueue(); + outputList = new LinkedBlockingQueue(); mockEnv.addOutput(outputList, outputStreamRecordSerializer); } @@ -265,7 +266,7 @@ public void waitForTaskRunning(long timeout) throws Exception { * {@link org.apache.flink.streaming.util.TestHarnessUtil#getRawElementsFromOutput(java.util.Queue)}} * to extract only the StreamRecords. 
*/ - public ConcurrentLinkedQueue getOutput() { + public LinkedBlockingQueue getOutput() { return outputList; } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java index 3718a947f9221..e0de7d2f39bfd 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java @@ -40,7 +40,7 @@ import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; -import org.apache.flink.runtime.checkpoint.savepoint.SavepointV1; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.executiongraph.TaskInformation; @@ -200,7 +200,7 @@ public void testTriggerSavepointAndResumeWithFileBasedCheckpoints() throws Excep LOG.info("Requesting the savepoint."); Future savepointFuture = jobManager.ask(new RequestSavepoint(savepointPath), deadline.timeLeft()); - SavepointV1 savepoint = (SavepointV1) ((ResponseSavepoint) Await.result(savepointFuture, deadline.timeLeft())).savepoint(); + SavepointV2 savepoint = (SavepointV2) ((ResponseSavepoint) Await.result(savepointFuture, deadline.timeLeft())).savepoint(); LOG.info("Retrieved savepoint: " + savepointPath + "."); // Shut down the Flink cluster (thereby canceling the job)