From 5caf3fd699087d503a9108585c8e6cf62b4090a2 Mon Sep 17 00:00:00 2001 From: "xiaogang.sxg" Date: Mon, 13 Mar 2017 19:23:47 +0800 Subject: [PATCH 1/4] Allow registration of state objects in checkpoints --- .../checkpoint/CheckpointCoordinator.java | 14 ++- .../checkpoint/CompletedCheckpoint.java | 41 ++++++--- .../checkpoint/CompletedCheckpointStore.java | 2 +- .../StandaloneCompletedCheckpointStore.java | 18 +++- .../runtime/checkpoint/SubtaskState.java | 52 ++++++++--- .../flink/runtime/checkpoint/TaskState.java | 23 ++++- .../ZooKeeperCompletedCheckpointStore.java | 86 +++++++++++++++-- .../runtime/state/CompositeStateHandle.java | 49 ++++++++++ .../runtime/state/SharedStateHandle.java | 38 ++++++++ .../runtime/state/SharedStateRegistry.java | 92 +++++++++++++++++++ .../CheckpointCoordinatorFailureTest.java | 21 +++-- .../CompletedCheckpointStoreTest.java | 10 +- .../checkpoint/CompletedCheckpointTest.java | 10 +- .../jobmanager/JobManagerHARecoveryTest.java | 10 +- 14 files changed, 405 insertions(+), 61 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index cc608377a21cf..1fd3b2e220aee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -792,14 +792,22 @@ private void completePendingCheckpoint(PendingCheckpoint pendingCheckpoint) thro rememberRecentCheckpointId(checkpointId); dropSubsumedCheckpoints(checkpointId); - } - catch (Exception exception) { + } catch (Exception exception) { // abort the current pending checkpoint if it has not been discarded yet if (!pendingCheckpoint.isDiscarded()) { pendingCheckpoint.abortError(exception); } if (completedCheckpoint != null) { + + // TODO:: fix possible recovery from corrupted checkpoints + // The completed checkpoint may have already been added into + // the store, but the method may still throw an exception + // due to other operations performed later (e.g., the subsuming + // of old checkpoints). To make the code work properly here, the + // store should not throw any exception if the checkpoint is + // already in the store. + // we failed to store the completed checkpoint. Let's clean up final CompletedCheckpoint cc = completedCheckpoint; @@ -807,7 +815,7 @@ private void completePendingCheckpoint(PendingCheckpoint pendingCheckpoint) thro @Override public void run() { try { - cc.discard(); + cc.discard(cc.getTaskStates().values()); } catch (Throwable t) { LOG.warn("Could not properly discard completed checkpoint {}.", cc.getCheckpointID(), t); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index 17ce4d51b8bab..bc6786a1facb0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -23,6 +23,8 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; @@ -32,6 +34,8 @@ import javax.annotation.Nullable; import java.io.Serializable; +import java.util.Collection; +import java.util.List; import java.util.Map; import static org.apache.flink.util.Preconditions.checkArgument; @@ -184,22 +188,36 @@ public CheckpointProperties getProperties() { return props; } - public boolean subsume() throws Exception { + public void register(StateRegistry stateRegistry) { + for (TaskState taskState : taskStates.values()) { + taskState.register(stateRegistry); + } + } + + public List unregister(StateRegistry stateRegistry) { + for (TaskState taskState : taskStates.values()) { + taskState.unregister(stateRegistry); + } + + return stateRegistry.getAndResetDiscardedStates(); + } + + public boolean subsume(Collection discardedStates) throws Exception { if (props.discardOnSubsumed()) { - discard(); + discard(discardedStates); return true; } return false; } - public boolean discard(JobStatus jobStatus) throws Exception { + public boolean discard(JobStatus jobStatus, Collection discardedStates) throws Exception { if (jobStatus == JobStatus.FINISHED && props.discardOnJobFinished() || jobStatus == JobStatus.CANCELED && props.discardOnJobCancelled() || jobStatus == JobStatus.FAILED && props.discardOnJobFailed() || jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) { - discard(); + discard(discardedStates); return true; } else { if (externalPointer != null) { @@ -211,7 +229,7 @@ public boolean discard(JobStatus jobStatus) throws Exception { } } - void discard() throws Exception { + void discard(Collection discardedStates) throws Exception { try { // collect exceptions and continue cleanup Exception exception = null; @@ -220,25 +238,22 @@ void discard() throws Exception { if (externalizedMetadata != null) { try { externalizedMetadata.discardState(); - } - catch (Exception e) { + } catch (Exception e) { exception = e; } } - // drop the actual state + // drop unreferenced state objects try { - StateUtil.bestEffortDiscardAllStateObjects(taskStates.values()); - } - catch (Exception e) { + StateUtil.bestEffortDiscardAllStateObjects(discardedStates); + } catch (Exception e) { exception = ExceptionUtils.firstOrSuppressed(e, exception); } if (exception != null) { throw exception; } - } - finally { + } finally { taskStates.clear(); // to be null-pointer safe, copy reference to stack diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java index 9c2b19983b109..0c721bd75cda0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java @@ -40,7 +40,7 @@ public interface CompletedCheckpointStore { * *

Only a bounded number of checkpoints is kept. When exceeding the maximum number of * retained checkpoints, the oldest one will be discarded via {@link - * CompletedCheckpoint#discard()}. + * CompletedCheckpoint#subsume(org.apache.flink.runtime.state.StateRegistry)} )}. */ void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java index 6eb5242ce2e26..7b1322e89a3c2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java @@ -20,6 +20,8 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,6 +41,9 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt /** The maximum number of checkpoints to retain (at least 1). */ private final int maxNumberOfCheckpointsToRetain; + /** The registry for completed checkpoints to register used state objects. */ + private final StateRegistry stateRegistry; + /** The completed checkpoints. */ private final ArrayDeque checkpoints; @@ -53,6 +58,7 @@ public StandaloneCompletedCheckpointStore(int maxNumberOfCheckpointsToRetain) { checkArgument(maxNumberOfCheckpointsToRetain >= 1, "Must retain at least one checkpoint."); this.maxNumberOfCheckpointsToRetain = maxNumberOfCheckpointsToRetain; this.checkpoints = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); + this.stateRegistry = new StateRegistry(); } @Override @@ -62,10 +68,15 @@ public void recover() throws Exception { @Override public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - checkpoints.add(checkpoint); + + checkpoint.register(stateRegistry); + checkpoints.addLast(checkpoint); + if (checkpoints.size() > maxNumberOfCheckpointsToRetain) { try { - checkpoints.remove().subsume(); + CompletedCheckpoint oldCheckpoint = checkpoints.remove(); + List discardedStates = oldCheckpoint.unregister(stateRegistry); + oldCheckpoint.subsume(discardedStates); } catch (Exception e) { LOG.warn("Fail to subsume the old checkpoint.", e); } @@ -98,7 +109,8 @@ public void shutdown(JobStatus jobStatus) throws Exception { LOG.info("Shutting down"); for (CompletedCheckpoint checkpoint : checkpoints) { - checkpoint.discard(jobStatus); + List discardedStates = checkpoint.unregister(stateRegistry); + checkpoint.discard(jobStatus, discardedStates); } } finally { checkpoints.clear(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index 9e195b116229c..f6da7b02cbc13 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -19,11 +19,15 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CompositeStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Arrays; @@ -33,7 +37,9 @@ * Container for the chained state of one parallel subtask of an operator/task. This is part of the * {@link TaskState}. */ -public class SubtaskState implements StateObject { +public class SubtaskState implements CompositeStateHandle { + + private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class); private static final long serialVersionUID = -2394696997971923995L; @@ -123,19 +129,41 @@ public KeyedStateHandle getRawKeyedState() { } @Override - public long getStateSize() { - return stateSize; + public void discardState() { + try { + StateUtil.bestEffortDiscardAllStateObjects( + Arrays.asList( + legacyOperatorState, + managedOperatorState, + rawOperatorState, + managedKeyedState, + rawKeyedState)); + } catch (Exception e) { + LOG.warn("Error while discarding operator states.", e); + } } @Override - public void discardState() throws Exception { - StateUtil.bestEffortDiscardAllStateObjects( - Arrays.asList( - legacyOperatorState, - managedOperatorState, - rawOperatorState, - managedKeyedState, - rawKeyedState)); + public void register(StateRegistry stateRegistry) { + stateRegistry.register(legacyOperatorState); + stateRegistry.register(managedOperatorState); + stateRegistry.register(rawOperatorState); + stateRegistry.register(managedKeyedState); + stateRegistry.register(rawKeyedState); + } + + @Override + public void unregister(StateRegistry stateRegistry) { + stateRegistry.unregister(legacyOperatorState); + stateRegistry.unregister(managedOperatorState); + stateRegistry.unregister(rawOperatorState); + stateRegistry.unregister(managedKeyedState); + stateRegistry.unregister(rawKeyedState); + } + + @Override + public long getStateSize() { + return stateSize; } // -------------------------------------------------------------------------------------------- @@ -199,7 +227,7 @@ public String toString() { ", operatorStateFromBackend=" + managedOperatorState + ", operatorStateFromStream=" + rawOperatorState + ", keyedStateFromBackend=" + managedKeyedState + - ", keyedStateHandleFromStream=" + rawKeyedState + + ", keyedStateFromStream=" + rawKeyedState + ", stateSize=" + stateSize + '}'; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index 76f1c510c50bf..52093fb15bbb4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -19,8 +19,8 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateUtil; +import org.apache.flink.runtime.state.CompositeStateHandle; +import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.util.Preconditions; import java.util.Collection; @@ -35,7 +35,7 @@ * * This class basically groups all non-partitioned state and key-group state belonging to the same job vertex together. */ -public class TaskState implements StateObject { +public class TaskState implements CompositeStateHandle { private static final long serialVersionUID = -4845578005863201810L; @@ -124,9 +124,24 @@ public boolean hasNonPartitionedState() { @Override public void discardState() throws Exception { - StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values()); + for (SubtaskState subtaskState : subtaskStates.values()) { + subtaskState.discardState(); + } + } + + @Override + public void register(StateRegistry snapshotRegistry) { + for (SubtaskState subtaskState : subtaskStates.values()) { + subtaskState.register(snapshotRegistry); + } } + @Override + public void unregister(StateRegistry snapshotRegistry) { + for (SubtaskState subtaskState : subtaskStates.values()) { + subtaskState.unregister(snapshotRegistry); + } + } @Override public long getStateSize() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java index af7bcc4f57f73..9d578f9bf21f2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java @@ -27,6 +27,8 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.FlinkException; @@ -80,6 +82,9 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto /** The maximum number of checkpoints to retain (at least 1). */ private final int maxNumberOfCheckpointsToRetain; + /** The registry for completed checkpoints to register used state objects. */ + private final StateRegistry stateRegistry; + /** Local completed checkpoints. */ private final ArrayDeque, String>> checkpointStateHandles; @@ -124,6 +129,8 @@ public ZooKeeperCompletedCheckpointStore( this.checkpointStateHandles = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); + this.stateRegistry = new StateRegistry(); + LOG.info("Initialized in '{}'.", checkpointsPath); } @@ -163,10 +170,40 @@ public void recover() throws Exception { int numberOfInitialCheckpoints = initialCheckpoints.size(); LOG.info("Found {} checkpoints in ZooKeeper.", numberOfInitialCheckpoints); + + for (Tuple2, String> initialCheckpoint : initialCheckpoints) { + long checkpointId = pathToCheckpointId(initialCheckpoint.f1); + LOG.info("Trying to retrieve checkpoint {}.", checkpointId); + + CompletedCheckpoint completedCheckpoint; + try { + completedCheckpoint = initialCheckpoint.f0.retrieveState(); + } catch (Exception e) { + throw new Exception( + "Could not retrieve the completed checkpoint " + + checkpointId + " from the state storage.", e); + } + + completedCheckpoint.register(stateRegistry); + } + + // Discard the checkpoints here. + for (int i = 0; i < numberOfInitialCheckpoints - 1; ++i) { + Tuple2, String> initialCheckpoint = initialCheckpoints.get(i); - for (Tuple2, String> checkpoint : initialCheckpoints) { - checkpointStateHandles.add(checkpoint); + try { + removeSubsumed(initialCheckpoint); + } catch (Exception e) { + LOG.error("Failed to discard checkpoint", e); + } } + + // Take the last one. This is the latest checkpoints, because path names are strictly + // increasing (checkpoint ID). + Tuple2, String> latest = initialCheckpoints + .get(numberOfInitialCheckpoints - 1); + + checkpointStateHandles.add(latest); } /** @@ -177,12 +214,29 @@ public void recover() throws Exception { @Override public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { checkNotNull(checkpoint, "Checkpoint"); + + final String path = checkpointIdToPath(checkpoint.getCheckpointID()); + final RetrievableStateHandle stateHandle; + + // Register the states in the checkpoint before performing any subsuming + synchronized (stateRegistry) { + checkpoint.register(stateRegistry); + } // First add the new one. If it fails, we don't want to loose existing data. - String path = checkpointIdToPath(checkpoint.getCheckpointID()); + try { + stateHandle = checkpointsInZooKeeper.add(path, checkpoint); + } catch (Exception e) { + // Unregister the states if we fail to add the new checkpoint. The + // list of unreferenced state objects produced by the unregistration + // is ignored here. The coordinator is supposed to discard these + // state objects correctly. + synchronized (stateRegistry) { + checkpoint.unregister(stateRegistry); + } - final RetrievableStateHandle stateHandle = - checkpointsInZooKeeper.add(path, checkpoint); + throw e; + } checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path)); @@ -293,7 +347,16 @@ private void removeSubsumed(final Tuple2 action = new Callable() { @Override public Void call() throws Exception { - stateHandleAndPath.f0.retrieveState().subsume(); + + CompletedCheckpoint completedCheckpoint = stateHandleAndPath.f0.retrieveState(); + + List discardedStates; + synchronized (stateRegistry) { + discardedStates = completedCheckpoint.unregister(stateRegistry); + } + + stateHandleAndPath.f0.retrieveState().subsume(discardedStates); + return null; } }; @@ -308,8 +371,15 @@ private void removeShutdown( Callable action = new Callable() { @Override public Void call() throws Exception { - CompletedCheckpoint checkpoint = stateHandleAndPath.f0.retrieveState(); - checkpoint.discard(jobStatus); + CompletedCheckpoint completedCheckpoint = stateHandleAndPath.f0.retrieveState(); + + List discardedStates; + synchronized (stateRegistry) { + discardedStates = completedCheckpoint.unregister(stateRegistry); + } + + stateHandleAndPath.f0.retrieveState().discard(jobStatus, discardedStates); + return null; } }; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java new file mode 100644 index 0000000000000..33b7c55a32bfb --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * Base of all snapshots that are taken by {@link StateBackend}s and some other + * components in tasks. + * + *

Each snapshot is composed of a collection of {@link StateObject}s. The + * {@link StateObject}s in a completed checkpoint may be referenced by other + * completed checkpoints. To avoid the deletion of those objects still in use, + * the handle should register all its objects when the checkpoint completes and + * unregister its objects when the checkpoint is discarded. + */ +public interface CompositeStateHandle extends StateObject { + + /** + * This method is called when the checkpoint is added into + * {@link org.apache.flink.runtime.checkpoint.CompletedCheckpointStore}. + * That happens when the pending checkpoint succeeds to complete or the + * completed checkpoint is reloaded in the recovery. In both cases, the + * snapshot handle should register all its objects in the given + * {@link StateRegistry}. + */ + void register(StateRegistry stateRegistry); + + /** + * This method is called when the completed checkpoint is discarded. In such + * cases, the snapshot handle should unregister all its objects. An object + * will be deleted if it is not referenced by any checkpoint. + */ + void unregister(StateRegistry stateRegistry); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java new file mode 100644 index 0000000000000..ad340f5651c33 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * Base for those state handles that are shared among different checkpoints. + * + * Each shared state handle is identified by an unique key. It will be + * registered at the {@link SharedStateRegistry} once it is received by the + * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}. Each + * registered state handle is unregistered when the checkpoint to which the + * state handle belongs is discarded. + */ +public interface SharedStateHandle extends StateObject { + + /** + * Returns the unique identifier of the state handle + * + * @return the unique identifier of the state handle + */ + String getKey(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java new file mode 100644 index 0000000000000..a171a2f97d3fe --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.java.tuple.Tuple2; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A {@code SharedStateRegistry} will be deployed in the + * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to + * maintain the reference count of the {@link SharedStateHandle}s. + */ +public class SharedStateRegistry { + + /** All registered state objects */ + private final Map> registeredStates = new HashMap<>(); + + /** All state objects that are not referenced any more */ + private List discardedStates = new ArrayList<>(); + + /** + * Register the state in the registry + * + * @param state The state to register + */ + public void register(StateObject state) { + synchronized (this) { + Integer referenceCount = registeredStates.get(state); + + if (referenceCount != null) { + registeredStates.put(state, referenceCount + 1); + } else { + registeredStates.put(state, 1); + } + } + } + + /** + * Decrease the reference count of the state in the registry + * + * @param state The state to unregister + */ + public void unregister(StateObject state) { + synchronized (this) { + Integer referenceCount = registeredStates.get(state); + + if (referenceCount == null) { + throw new IllegalStateException("Cannot unregister an unexisted state."); + } + + referenceCount--; + + if (referenceCount == 0) { + registeredStates.remove(state); + discardedStates.add(state); + } + } + } + + /** + * Gets and resets the list of discarded state objects + * + * @return A list of cached unreferenced state objects + */ + public List getAndResetDiscardedStates() { + synchronized (this) { + List result = new ArrayList<>(discardedStates); + discardedStates.clear(); + return result; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 340e2a78607e8..5d7e0e07de8de 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -25,13 +25,17 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import java.util.Collections; import java.util.List; import static org.junit.Assert.assertEquals; @@ -40,6 +44,7 @@ import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(PendingCheckpoint.class) @@ -83,12 +88,13 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertFalse(pendingCheckpoint.isDiscarded()); final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); - - AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId); - - CompletedCheckpoint completedCheckpoint = mock(CompletedCheckpoint.class); - PowerMockito.whenNew(CompletedCheckpoint.class).withAnyArguments().thenReturn(completedCheckpoint); - + + SubtaskState subtaskState = mock(SubtaskState.class); + PowerMockito.when(subtaskState.getLegacyOperatorState()).thenReturn(null); + PowerMockito.when(subtaskState.getManagedOperatorState()).thenReturn(null); + + AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState); + try { coord.receiveAcknowledgeMessage(acknowledgeMessage); fail("Expected a checkpoint exception because the completed checkpoint store could not " + @@ -100,7 +106,8 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { // make sure that the pending checkpoint has been discarded after we could not complete it assertTrue(pendingCheckpoint.isDiscarded()); - verify(completedCheckpoint).discard(); + // make sure that the subtask state has been discarded after we could not complete it. + verify(subtaskState).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index f77c755797827..f6a7fdaed8470 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -23,11 +23,13 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.CheckpointMessagesTest; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; import org.junit.Test; import java.io.IOException; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -241,8 +243,8 @@ public TestCompletedCheckpoint( } @Override - public boolean subsume() throws Exception { - if (super.subsume()) { + public boolean subsume(Collection discardedStates) throws Exception { + if (super.subsume(discardedStates)) { discard(); return true; } else { @@ -251,8 +253,8 @@ public boolean subsume() throws Exception { } @Override - public boolean discard(JobStatus jobStatus) throws Exception { - if (super.discard(jobStatus)) { + public boolean discard(JobStatus jobStatus, Collection discardedStates) throws Exception { + if (super.discard(jobStatus, discardedStates)) { discard(); return true; } else { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index b34e9a6ac26f5..d0c98ac3c2df0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -61,7 +61,7 @@ public void testDiscard() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); - checkpoint.discard(JobStatus.FAILED); + checkpoint.discard(JobStatus.FAILED, checkpoint.getTaskStates().values()); assertEquals(false, file.exists()); } @@ -81,7 +81,7 @@ public void testCleanUpOnSubsume() throws Exception { new JobID(), 0, 0, 1, taskStates, props); // Subsume - checkpoint.subsume(); + checkpoint.subsume(taskStates.values()); verify(state, times(1)).discardState(); } @@ -112,7 +112,7 @@ public void testCleanUpOnShutdown() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); - checkpoint.discard(status); + checkpoint.discard(status, taskStates.values()); verify(state, times(0)).discardState(); assertEquals(true, file.exists()); @@ -121,7 +121,7 @@ public void testCleanUpOnShutdown() throws Exception { checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); - checkpoint.discard(status); + checkpoint.discard(taskStates.values()); verify(state, times(1)).discardState(); } } @@ -146,7 +146,7 @@ public void testCompletedCheckpointStatsCallbacks() throws Exception { CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); completed.setDiscardCallback(callback); - completed.discard(JobStatus.FINISHED); + completed.discard(JobStatus.FINISHED, taskStates.values()); verify(callback, times(1)).notifyDiscardedCheckpoint(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index dcf4722a7313e..a2cda84428aff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -71,6 +71,8 @@ import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; @@ -448,6 +450,8 @@ static class MyCheckpointStore implements CompletedCheckpointStore { private final ArrayDeque suspended = new ArrayDeque<>(2); + private final StateRegistry stateRegistry = new StateRegistry(); + @Override public void recover() throws Exception { checkpoints.addAll(suspended); @@ -456,9 +460,13 @@ public void recover() throws Exception { @Override public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { + checkpoint.register(stateRegistry); checkpoints.addLast(checkpoint); + if (checkpoints.size() > 1) { - checkpoints.removeFirst().subsume(); + CompletedCheckpoint subsumedCheckpoint = checkpoints.removeFirst(); + List discardedStates = subsumedCheckpoint.unregister(stateRegistry); + subsumedCheckpoint.subsume(discardedStates); } } From bd54de1980755c2557a990e423aef543d982651d Mon Sep 17 00:00:00 2001 From: "xiaogang.sxg" Date: Mon, 13 Mar 2017 19:29:21 +0800 Subject: [PATCH 2/4] remove unnecessary synchronization in StateRegistry --- .../runtime/state/SharedStateRegistry.java | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java index a171a2f97d3fe..9f3e76c7e873a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -44,14 +44,12 @@ public class SharedStateRegistry { * @param state The state to register */ public void register(StateObject state) { - synchronized (this) { - Integer referenceCount = registeredStates.get(state); + Integer referenceCount = registeredStates.get(state); - if (referenceCount != null) { - registeredStates.put(state, referenceCount + 1); - } else { - registeredStates.put(state, 1); - } + if (referenceCount != null) { + registeredStates.put(state, referenceCount + 1); + } else { + registeredStates.put(state, 1); } } @@ -61,19 +59,17 @@ public void register(StateObject state) { * @param state The state to unregister */ public void unregister(StateObject state) { - synchronized (this) { - Integer referenceCount = registeredStates.get(state); + Integer referenceCount = registeredStates.get(state); - if (referenceCount == null) { - throw new IllegalStateException("Cannot unregister an unexisted state."); - } + if (referenceCount == null) { + throw new IllegalStateException("Cannot unregister an unexisted state."); + } - referenceCount--; + referenceCount--; - if (referenceCount == 0) { - registeredStates.remove(state); - discardedStates.add(state); - } + if (referenceCount == 0) { + registeredStates.remove(state); + discardedStates.add(state); } } @@ -83,10 +79,8 @@ public void unregister(StateObject state) { * @return A list of cached unreferenced state objects */ public List getAndResetDiscardedStates() { - synchronized (this) { - List result = new ArrayList<>(discardedStates); - discardedStates.clear(); - return result; - } + List result = new ArrayList<>(discardedStates); + discardedStates.clear(); + return result; } } From 0d99c19d0946448d23bd0bfd26d47077dc43f67d Mon Sep 17 00:00:00 2001 From: "xiaogang.sxg" Date: Fri, 24 Mar 2017 15:29:49 +0800 Subject: [PATCH 3/4] Only register shared states in the registry --- .../checkpoint/CheckpointCoordinator.java | 183 ++++++++++---- .../checkpoint/CompletedCheckpoint.java | 32 +-- .../checkpoint/CompletedCheckpointStore.java | 7 +- .../runtime/checkpoint/PendingCheckpoint.java | 23 +- .../StandaloneCompletedCheckpointStore.java | 37 +-- .../runtime/checkpoint/SubtaskState.java | 18 +- .../flink/runtime/checkpoint/TaskState.java | 10 +- .../ZooKeeperCompletedCheckpointStore.java | 114 +++------ .../runtime/state/CompositeStateHandle.java | 36 +-- .../runtime/state/SharedStateHandle.java | 38 --- .../runtime/state/SharedStateRegistry.java | 150 ++++++++++-- .../flink/runtime/jobmanager/JobManager.scala | 27 +-- .../CheckpointCoordinatorFailureTest.java | 13 +- .../checkpoint/CheckpointCoordinatorTest.java | 226 ++++++++++++++---- .../CheckpointStateRestoreTest.java | 4 +- .../CompletedCheckpointStoreTest.java | 48 ++-- .../checkpoint/CompletedCheckpointTest.java | 10 +- ...ecutionGraphCheckpointCoordinatorTest.java | 9 +- .../checkpoint/PendingCheckpointTest.java | 32 ++- ...tandaloneCompletedCheckpointStoreTest.java | 22 +- ...oKeeperCompletedCheckpointStoreITCase.java | 25 +- ...ZooKeeperCompletedCheckpointStoreTest.java | 7 +- .../jobmanager/JobManagerHARecoveryTest.java | 73 +----- .../state/SharedStateRegistryTest.java | 127 ++++++++++ .../RecoverableCompletedCheckpointStore.java | 116 +++++++++ 25 files changed, 923 insertions(+), 464 deletions(-) delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 1fd3b2e220aee..e0ff1980ff61b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -21,6 +21,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader; import org.apache.flink.runtime.checkpoint.savepoint.SavepointStore; import org.apache.flink.runtime.concurrent.ApplyFunction; import org.apache.flink.runtime.concurrent.Future; @@ -36,10 +37,13 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -108,6 +112,9 @@ public class CheckpointCoordinator { /** Completed checkpoints. Implementations can be blocking. Make sure calls to methods * accessing this don't block the job manager actor and run asynchronously. */ private final CompletedCheckpointStore completedCheckpointStore; + + /** Registry for shared states */ + private final SharedStateRegistry sharedStateRegistry; /** Default directory for persistent checkpoints; null if none configured. * THIS WILL BE REPLACED BY PROPER STATE-BACKEND METADATA WRITING */ @@ -218,6 +225,7 @@ public CheckpointCoordinator( this.completedCheckpointStore = checkNotNull(completedCheckpointStore); this.checkpointDirectory = checkpointDirectory; this.executor = checkNotNull(executor); + this.sharedStateRegistry = new SharedStateRegistry(); this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS); @@ -282,7 +290,7 @@ public void shutdown(JobStatus jobStatus) throws Exception { } pendingCheckpoints.clear(); - completedCheckpointStore.shutdown(jobStatus); + completedCheckpointStore.shutdown(jobStatus, sharedStateRegistry); checkpointIdCounter.shutdown(jobStatus); } } @@ -481,7 +489,8 @@ CheckpointTriggerResult triggerCheckpoint( ackTasks, props, targetDirectory, - executor); + executor, + sharedStateRegistry); if (statsTracker != null) { PendingCheckpointStats callback = statsTracker.reportPendingCheckpoint( @@ -615,7 +624,7 @@ public void receiveDeclineMessage(DeclineCheckpoint message) { throw new IllegalArgumentException("Received DeclineCheckpoint message for job " + message.getJob() + " while this coordinator handles job " + job); } - + final long checkpointId = message.getCheckpointId(); final String reason = (message.getReason() != null ? message.getReason().getMessage() : ""); @@ -695,7 +704,7 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C } final long checkpointId = message.getCheckpointId(); - + synchronized (lock) { // we need to check inside the lock for being shutdown as well, otherwise we // get races and invalid error log messages @@ -709,6 +718,8 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C switch (checkpoint.acknowledgeTask(message.getTaskExecutionId(), message.getSubtaskState(), message.getCheckpointMetrics())) { case SUCCESS: + sharedStateRegistry.registerAll(message.getSubtaskState()); + LOG.debug("Received acknowledge message for checkpoint {} from task {} of job {}.", checkpointId, message.getTaskExecutionId(), message.getJob()); @@ -721,6 +732,8 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C message.getCheckpointId(), message.getTaskExecutionId(), message.getJob()); break; case UNKNOWN: + sharedStateRegistry.registerAll(message.getSubtaskState()); + LOG.warn("Could not acknowledge the checkpoint {} for task {} of job {}, " + "because the task's execution attempt id was unknown. Discarding " + "the state handle to avoid lingering state.", message.getCheckpointId(), @@ -730,6 +743,8 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C break; case DISCARDED: + sharedStateRegistry.registerAll(message.getSubtaskState()); + LOG.warn("Could not acknowledge the checkpoint {} for task {} of job {}, " + "because the pending checkpoint had been discarded. Discarding the " + "state handle tp avoid lingering state.", @@ -748,6 +763,8 @@ else if (checkpoint != null) { else { boolean wasPendingCheckpoint; + sharedStateRegistry.registerAll(message.getSubtaskState()); + // message is for an unknown checkpoint, or comes too late (checkpoint disposed) if (recentPendingCheckpoints.contains(checkpointId)) { wasPendingCheckpoint = true; @@ -778,57 +795,57 @@ else if (checkpoint != null) { */ private void completePendingCheckpoint(PendingCheckpoint pendingCheckpoint) throws CheckpointException { final long checkpointId = pendingCheckpoint.getCheckpointId(); - CompletedCheckpoint completedCheckpoint = null; + final CompletedCheckpoint completedCheckpoint; try { - // externalize the checkpoint if required - if (pendingCheckpoint.getProps().externalizeCheckpoint()) { - completedCheckpoint = pendingCheckpoint.finalizeCheckpointExternalized(); - } else { - completedCheckpoint = pendingCheckpoint.finalizeCheckpointNonExternalized(); - } - - completedCheckpointStore.addCheckpoint(completedCheckpoint); - - rememberRecentCheckpointId(checkpointId); - dropSubsumedCheckpoints(checkpointId); - } catch (Exception exception) { - // abort the current pending checkpoint if it has not been discarded yet - if (!pendingCheckpoint.isDiscarded()) { - pendingCheckpoint.abortError(exception); + try { + // externalize the checkpoint if required + if (pendingCheckpoint.getProps().externalizeCheckpoint()) { + completedCheckpoint = pendingCheckpoint.finalizeCheckpointExternalized(); + } else { + completedCheckpoint = pendingCheckpoint.finalizeCheckpointNonExternalized(); + } + } catch (Exception e1) { + // abort the current pending checkpoint if we fails to finalize the pending checkpoint. + if (!pendingCheckpoint.isDiscarded()) { + pendingCheckpoint.abortError(e1); + } + + throw new CheckpointException("Could not finalize the pending checkpoint " + checkpointId + '.', e1); } - - if (completedCheckpoint != null) { - - // TODO:: fix possible recovery from corrupted checkpoints - // The completed checkpoint may have already been added into - // the store, but the method may still throw an exception - // due to other operations performed later (e.g., the subsuming - // of old checkpoints). To make the code work properly here, the - // store should not throw any exception if the checkpoint is - // already in the store. - + + // the pending checkpoint must be discarded after the finalization + Preconditions.checkState(pendingCheckpoint.isDiscarded() && completedCheckpoint != null); + + try { + completedCheckpointStore.addCheckpoint(completedCheckpoint, sharedStateRegistry); + } catch (Exception exception) { // we failed to store the completed checkpoint. Let's clean up - final CompletedCheckpoint cc = completedCheckpoint; - executor.execute(new Runnable() { @Override public void run() { + sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); + try { - cc.discard(cc.getTaskStates().values()); + completedCheckpoint.discard(); } catch (Throwable t) { - LOG.warn("Could not properly discard completed checkpoint {}.", cc.getCheckpointID(), t); + LOG.warn("Could not properly discard completed checkpoint {}.", completedCheckpoint.getCheckpointID(), t); } } }); + + throw new CheckpointException("Could not complete the pending checkpoint " + checkpointId + '.', exception); } - - throw new CheckpointException("Could not complete the pending checkpoint " + checkpointId + '.', exception); } finally { pendingCheckpoints.remove(checkpointId); triggerQueuedRequests(); } + + rememberRecentCheckpointId(checkpointId); + + // drop those pending checkpoints that are at prior to the completed one + dropSubsumedCheckpoints(checkpointId); // record the time when this was completed, to calculate // the 'min delay between checkpoints' @@ -951,10 +968,19 @@ public boolean restoreLatestCheckpointedState( // Recover the checkpoints completedCheckpointStore.recover(); - // restore from the latest checkpoint - CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint(); + // Recover the registry for shared states + CompletedCheckpoint latestCompletedCheckpoint = null; + List completedCheckpoints = completedCheckpointStore.getAllCheckpoints(); + for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { + sharedStateRegistry.registerAll(completedCheckpoint.getTaskStates().values()); - if (latest == null) { + if (latestCompletedCheckpoint == null || + latestCompletedCheckpoint.getCheckpointID() > completedCheckpoint.getCheckpointID()) { + latestCompletedCheckpoint = completedCheckpoint; + } + } + + if (latestCompletedCheckpoint == null) { if (errorIfNoCheckpoint) { throw new IllegalStateException("No completed checkpoint available"); } else { @@ -962,9 +988,9 @@ public boolean restoreLatestCheckpointedState( } } - LOG.info("Restoring from latest valid checkpoint: {}.", latest); + LOG.info("Restoring from latest valid checkpoint: {}.", latestCompletedCheckpoint); - final Map taskStates = latest.getTaskStates(); + final Map taskStates = latestCompletedCheckpoint.getTaskStates(); StateAssignmentOperation stateAssignmentOperation = new StateAssignmentOperation(LOG, tasks, taskStates, allowNonRestoredState); @@ -974,10 +1000,10 @@ public boolean restoreLatestCheckpointedState( if (statsTracker != null) { long restoreTimestamp = System.currentTimeMillis(); RestoredCheckpointStats restored = new RestoredCheckpointStats( - latest.getCheckpointID(), - latest.getProperties(), + latestCompletedCheckpoint.getCheckpointID(), + latestCompletedCheckpoint.getProperties(), restoreTimestamp, - latest.getExternalPointer()); + latestCompletedCheckpoint.getExternalPointer()); statsTracker.reportRestoredCheckpoint(restored); } @@ -986,6 +1012,44 @@ public boolean restoreLatestCheckpointedState( } } + /** + * Restore the state with given savepoint + * + * @param savepointPath Location of the savepoint + * @param allowNonRestored True if allowing checkpoint state that cannot be + * mapped to any job vertex in tasks. + * @param tasks Map of job vertices to restore. State for these + * vertices is restored via + * {@link Execution#setInitialState(TaskStateHandles)}. + * @param userClassLoader The class loader to resolve serialized classes in + * legacy savepoint versions. + */ + public boolean restoreSavepoint( + String savepointPath, + boolean allowNonRestored, + Map tasks, + ClassLoader userClassLoader) throws Exception { + + Preconditions.checkNotNull(savepointPath, "The savepoint path cannot be null."); + + LOG.info("Starting job from savepoint {} ({})", + savepointPath, (allowNonRestored ? "allowing non restored state" : "")); + + // Load the savepoint as a checkpoint into the system + CompletedCheckpoint savepoint = SavepointLoader.loadAndValidateSavepoint( + job, tasks, savepointPath, userClassLoader, allowNonRestored); + + completedCheckpointStore.addCheckpoint(savepoint, sharedStateRegistry); + + // Reset the checkpoint ID counter + long nextCheckpointId = savepoint.getCheckpointID() + 1; + checkpointIdCounter.setCount(nextCheckpointId); + + LOG.info("Reset the checkpoint ID to {}.", nextCheckpointId); + + return restoreLatestCheckpointedState(tasks, true, allowNonRestored); + } + // -------------------------------------------------------------------------------------------- // Accessors // -------------------------------------------------------------------------------------------- @@ -1015,6 +1079,10 @@ public List getSuccessfulCheckpoints() throws Exception { public CompletedCheckpointStore getCheckpointStore() { return completedCheckpointStore; } + + public SharedStateRegistry getSharedStateRegistry() { + return sharedStateRegistry; + } public CheckpointIDCounter getCheckpointIdCounter() { return checkpointIdCounter; @@ -1103,24 +1171,33 @@ public void run() { * @param jobId identifying the job to which the state object belongs * @param executionAttemptID identifying the task to which the state object belongs * @param checkpointId of the state object - * @param stateObject to discard asynchronously + * @param subtaskState to discard asynchronously */ private void discardState( final JobID jobId, final ExecutionAttemptID executionAttemptID, final long checkpointId, - final StateObject stateObject) { + final SubtaskState subtaskState) { - if (stateObject != null) { + if (subtaskState != null) { executor.execute(new Runnable() { @Override public void run() { + List discardedSharedStates = sharedStateRegistry.unregisterAll(subtaskState); + + try { + StateUtil.bestEffortDiscardAllStateObjects(discardedSharedStates); + } catch (Throwable t1) { + LOG.warn("Could not properly discard shared states of checkpoint {} " + + "belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t1 + ); + } + try { - stateObject.discardState(); - } catch (Throwable throwable) { - LOG.warn("Could not properly discard state object of checkpoint {} " + - "belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, - throwable); + subtaskState.discardState(); + } catch (Throwable t2) { + LOG.warn("Could not properly discard state object of checkpoint {} " + + "belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t2); } } }); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index bc6786a1facb0..d1b182fd38967 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -23,8 +23,6 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; @@ -34,8 +32,6 @@ import javax.annotation.Nullable; import java.io.Serializable; -import java.util.Collection; -import java.util.List; import java.util.Map; import static org.apache.flink.util.Preconditions.checkArgument; @@ -188,36 +184,22 @@ public CheckpointProperties getProperties() { return props; } - public void register(StateRegistry stateRegistry) { - for (TaskState taskState : taskStates.values()) { - taskState.register(stateRegistry); - } - } - - public List unregister(StateRegistry stateRegistry) { - for (TaskState taskState : taskStates.values()) { - taskState.unregister(stateRegistry); - } - - return stateRegistry.getAndResetDiscardedStates(); - } - - public boolean subsume(Collection discardedStates) throws Exception { + public boolean subsume() throws Exception { if (props.discardOnSubsumed()) { - discard(discardedStates); + discard(); return true; } return false; } - public boolean discard(JobStatus jobStatus, Collection discardedStates) throws Exception { + public boolean discard(JobStatus jobStatus) throws Exception { if (jobStatus == JobStatus.FINISHED && props.discardOnJobFinished() || jobStatus == JobStatus.CANCELED && props.discardOnJobCancelled() || jobStatus == JobStatus.FAILED && props.discardOnJobFailed() || jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) { - discard(discardedStates); + discard(); return true; } else { if (externalPointer != null) { @@ -229,7 +211,7 @@ public boolean discard(JobStatus jobStatus, Collection di } } - void discard(Collection discardedStates) throws Exception { + void discard() throws Exception { try { // collect exceptions and continue cleanup Exception exception = null; @@ -243,9 +225,9 @@ void discard(Collection discardedStates) throws Exception } } - // drop unreferenced state objects + // discard private state objects try { - StateUtil.bestEffortDiscardAllStateObjects(discardedStates); + StateUtil.bestEffortDiscardAllStateObjects(taskStates.values()); } catch (Exception e) { exception = ExceptionUtils.firstOrSuppressed(e, exception); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java index 0c721bd75cda0..27d4d686845e8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.state.SharedStateRegistry; import java.util.List; @@ -40,9 +41,9 @@ public interface CompletedCheckpointStore { * *

Only a bounded number of checkpoints is kept. When exceeding the maximum number of * retained checkpoints, the oldest one will be discarded via {@link - * CompletedCheckpoint#subsume(org.apache.flink.runtime.state.StateRegistry)} )}. + * CompletedCheckpoint#subsume()} )}. */ - void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception; + void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception; /** * Returns the latest {@link CompletedCheckpoint} instance or null if none was @@ -58,7 +59,7 @@ public interface CompletedCheckpointStore { * * @param jobStatus Job state on shut down */ - void shutdown(JobStatus jobStatus) throws Exception; + void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception; /** * Returns all {@link CompletedCheckpoint} instances. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index b7eb037a10c4b..34df45fc3b942 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -29,6 +29,8 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; @@ -43,6 +45,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; @@ -92,6 +95,9 @@ public class PendingCheckpoint { /** The executor for potentially blocking I/O operations, like state disposal */ private final Executor executor; + /** The registry where shared states are registered */ + private final SharedStateRegistry sharedStateRegistry; + private int numAcknowledgedTasks; private boolean discarded; @@ -111,7 +117,8 @@ public PendingCheckpoint( Map verticesToConfirm, CheckpointProperties props, String targetDirectory, - Executor executor) { + Executor executor, + SharedStateRegistry sharedStateRegistry) { // Sanity check if (props.externalizeCheckpoint() && targetDirectory == null) { @@ -128,6 +135,7 @@ public PendingCheckpoint( this.props = checkNotNull(props); this.targetDirectory = targetDirectory; this.executor = Preconditions.checkNotNull(executor); + this.sharedStateRegistry = Preconditions.checkNotNull(sharedStateRegistry); this.taskStates = new HashMap<>(); this.acknowledgedTasks = new HashSet<>(verticesToConfirm.size()); @@ -491,6 +499,7 @@ public void abortError(Throwable cause) { } private void dispose(boolean releaseState) { + synchronized (lock) { try { numAcknowledgedTasks = -1; @@ -498,11 +507,19 @@ private void dispose(boolean releaseState) { executor.execute(new Runnable() { @Override public void run() { + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(taskStates.values()); + + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Throwable t) { + LOG.warn("Could not properly dispose unreferenced shared states."); + } + try { StateUtil.bestEffortDiscardAllStateObjects(taskStates.values()); } catch (Throwable t) { - LOG.warn("Could not properly dispose the pending checkpoint {} of job {}.", - checkpointId, jobId, t); + LOG.warn("Could not properly dispose the private states in the pending checkpoint {} of job {}.", + checkpointId, jobId, t); } finally { taskStates.clear(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java index 7b1322e89a3c2..65faef207b2d1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java @@ -20,8 +20,9 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateRegistry; +import org.apache.flink.runtime.state.StateUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,9 +42,6 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt /** The maximum number of checkpoints to retain (at least 1). */ private final int maxNumberOfCheckpointsToRetain; - /** The registry for completed checkpoints to register used state objects. */ - private final StateRegistry stateRegistry; - /** The completed checkpoints. */ private final ArrayDeque checkpoints; @@ -58,7 +56,6 @@ public StandaloneCompletedCheckpointStore(int maxNumberOfCheckpointsToRetain) { checkArgument(maxNumberOfCheckpointsToRetain >= 1, "Must retain at least one checkpoint."); this.maxNumberOfCheckpointsToRetain = maxNumberOfCheckpointsToRetain; this.checkpoints = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); - this.stateRegistry = new StateRegistry(); } @Override @@ -67,16 +64,22 @@ public void recover() throws Exception { } @Override - public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - - checkpoint.register(stateRegistry); + public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception { + checkpoints.addLast(checkpoint); if (checkpoints.size() > maxNumberOfCheckpointsToRetain) { try { - CompletedCheckpoint oldCheckpoint = checkpoints.remove(); - List discardedStates = oldCheckpoint.unregister(stateRegistry); - oldCheckpoint.subsume(discardedStates); + CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); + + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpointToSubsume.getTaskStates().values()); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Exception e) { + LOG.warn("Could not properly discard unreferenced shared states.", e); + } + + checkpointToSubsume.subsume(); } catch (Exception e) { LOG.warn("Fail to subsume the old checkpoint.", e); } @@ -104,13 +107,19 @@ public int getMaxNumberOfRetainedCheckpoints() { } @Override - public void shutdown(JobStatus jobStatus) throws Exception { + public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { try { LOG.info("Shutting down"); for (CompletedCheckpoint checkpoint : checkpoints) { - List discardedStates = checkpoint.unregister(stateRegistry); - checkpoint.discard(jobStatus, discardedStates); + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Exception e) { + LOG.warn("Could not properly discard unreferenced shared states.", e); + } + + checkpoint.discard(jobStatus); } } finally { checkpoints.clear(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index f6da7b02cbc13..487b191fed9ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -22,8 +22,8 @@ import org.apache.flink.runtime.state.CompositeStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateRegistry; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.slf4j.Logger; @@ -144,21 +144,13 @@ public void discardState() { } @Override - public void register(StateRegistry stateRegistry) { - stateRegistry.register(legacyOperatorState); - stateRegistry.register(managedOperatorState); - stateRegistry.register(rawOperatorState); - stateRegistry.register(managedKeyedState); - stateRegistry.register(rawKeyedState); + public void register(SharedStateRegistry sharedStateRegistry) { + // No shared states } @Override - public void unregister(StateRegistry stateRegistry) { - stateRegistry.unregister(legacyOperatorState); - stateRegistry.unregister(managedOperatorState); - stateRegistry.unregister(rawOperatorState); - stateRegistry.unregister(managedKeyedState); - stateRegistry.unregister(rawKeyedState); + public void unregister(SharedStateRegistry sharedStateRegistry) { + // No shared states } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index 52093fb15bbb4..dd86c87602b47 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -20,7 +20,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.CompositeStateHandle; -import org.apache.flink.runtime.state.StateRegistry; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.util.Preconditions; import java.util.Collection; @@ -130,16 +130,16 @@ public void discardState() throws Exception { } @Override - public void register(StateRegistry snapshotRegistry) { + public void register(SharedStateRegistry sharedStateRegistry) { for (SubtaskState subtaskState : subtaskStates.values()) { - subtaskState.register(snapshotRegistry); + subtaskState.register(sharedStateRegistry); } } @Override - public void unregister(StateRegistry snapshotRegistry) { + public void unregister(SharedStateRegistry sharedStateRegistry) { for (SubtaskState subtaskState : subtaskStates.values()) { - subtaskState.unregister(snapshotRegistry); + subtaskState.unregister(sharedStateRegistry); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java index 9d578f9bf21f2..33da79baf6901 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java @@ -27,8 +27,9 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateRegistry; +import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.FlinkException; @@ -82,9 +83,6 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto /** The maximum number of checkpoints to retain (at least 1). */ private final int maxNumberOfCheckpointsToRetain; - /** The registry for completed checkpoints to register used state objects. */ - private final StateRegistry stateRegistry; - /** Local completed checkpoints. */ private final ArrayDeque, String>> checkpointStateHandles; @@ -128,9 +126,7 @@ public ZooKeeperCompletedCheckpointStore( this.checkpointsInZooKeeper = new ZooKeeperStateHandleStore<>(this.client, stateStorage, executor); this.checkpointStateHandles = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); - - this.stateRegistry = new StateRegistry(); - + LOG.info("Initialized in '{}'.", checkpointsPath); } @@ -170,40 +166,10 @@ public void recover() throws Exception { int numberOfInitialCheckpoints = initialCheckpoints.size(); LOG.info("Found {} checkpoints in ZooKeeper.", numberOfInitialCheckpoints); - - for (Tuple2, String> initialCheckpoint : initialCheckpoints) { - long checkpointId = pathToCheckpointId(initialCheckpoint.f1); - LOG.info("Trying to retrieve checkpoint {}.", checkpointId); - - CompletedCheckpoint completedCheckpoint; - try { - completedCheckpoint = initialCheckpoint.f0.retrieveState(); - } catch (Exception e) { - throw new Exception( - "Could not retrieve the completed checkpoint " + - checkpointId + " from the state storage.", e); - } - - completedCheckpoint.register(stateRegistry); - } - - // Discard the checkpoints here. - for (int i = 0; i < numberOfInitialCheckpoints - 1; ++i) { - Tuple2, String> initialCheckpoint = initialCheckpoints.get(i); - try { - removeSubsumed(initialCheckpoint); - } catch (Exception e) { - LOG.error("Failed to discard checkpoint", e); - } + for (Tuple2, String> checkpoint : initialCheckpoints) { + checkpointStateHandles.add(checkpoint); } - - // Take the last one. This is the latest checkpoints, because path names are strictly - // increasing (checkpoint ID). - Tuple2, String> latest = initialCheckpoints - .get(numberOfInitialCheckpoints - 1); - - checkpointStateHandles.add(latest); } /** @@ -212,38 +178,21 @@ public void recover() throws Exception { * @param checkpoint Completed checkpoint to add. */ @Override - public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { + public void addCheckpoint(final CompletedCheckpoint checkpoint, final SharedStateRegistry sharedStateRegistry) throws Exception { checkNotNull(checkpoint, "Checkpoint"); final String path = checkpointIdToPath(checkpoint.getCheckpointID()); final RetrievableStateHandle stateHandle; - // Register the states in the checkpoint before performing any subsuming - synchronized (stateRegistry) { - checkpoint.register(stateRegistry); - } - // First add the new one. If it fails, we don't want to loose existing data. - try { - stateHandle = checkpointsInZooKeeper.add(path, checkpoint); - } catch (Exception e) { - // Unregister the states if we fail to add the new checkpoint. The - // list of unreferenced state objects produced by the unregistration - // is ignored here. The coordinator is supposed to discard these - // state objects correctly. - synchronized (stateRegistry) { - checkpoint.unregister(stateRegistry); - } - - throw e; - } + stateHandle = checkpointsInZooKeeper.add(path, checkpoint); checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path)); // Everything worked, let's remove a previous checkpoint if necessary. while (checkpointStateHandles.size() > maxNumberOfCheckpointsToRetain) { try { - removeSubsumed(checkpointStateHandles.removeFirst()); + removeSubsumed(checkpointStateHandles.removeFirst(), sharedStateRegistry); } catch (Exception e) { LOG.warn("Failed to subsume the old checkpoint", e); } @@ -315,13 +264,13 @@ public int getMaxNumberOfRetainedCheckpoints() { } @Override - public void shutdown(JobStatus jobStatus) throws Exception { + public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { if (jobStatus.isGloballyTerminalState()) { LOG.info("Shutting down"); for (Tuple2, String> checkpoint : checkpointStateHandles) { try { - removeShutdown(checkpoint, jobStatus); + removeShutdown(checkpoint, jobStatus, sharedStateRegistry); } catch (Exception e) { LOG.error("Failed to discard checkpoint.", e); } @@ -343,20 +292,27 @@ public void shutdown(JobStatus jobStatus) throws Exception { // ------------------------------------------------------------------------ - private void removeSubsumed(final Tuple2, String> stateHandleAndPath) throws Exception { + private void removeSubsumed( + final Tuple2, String> stateHandleAndPath, + final SharedStateRegistry sharedStateRegistry) throws Exception { + Callable action = new Callable() { @Override public Void call() throws Exception { + CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath); + + if (completedCheckpoint != null) { + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); - CompletedCheckpoint completedCheckpoint = stateHandleAndPath.f0.retrieveState(); - - List discardedStates; - synchronized (stateRegistry) { - discardedStates = completedCheckpoint.unregister(stateRegistry); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Throwable t) { + LOG.warn("Could not properly discard unreferenced shared states.", t); + } + + completedCheckpoint.subsume(); } - stateHandleAndPath.f0.retrieveState().subsume(discardedStates); - return null; } }; @@ -366,20 +322,26 @@ public Void call() throws Exception { private void removeShutdown( final Tuple2, String> stateHandleAndPath, - final JobStatus jobStatus) throws Exception { + final JobStatus jobStatus, + final SharedStateRegistry sharedStateRegistry) throws Exception { Callable action = new Callable() { @Override public Void call() throws Exception { - CompletedCheckpoint completedCheckpoint = stateHandleAndPath.f0.retrieveState(); + CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath); + + if (completedCheckpoint != null) { + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); - List discardedStates; - synchronized (stateRegistry) { - discardedStates = completedCheckpoint.unregister(stateRegistry); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Throwable t) { + LOG.warn("Could not properly discard unreferenced shared states.", t); + } + + completedCheckpoint.discard(jobStatus); } - stateHandleAndPath.f0.retrieveState().discard(jobStatus, discardedStates); - return null; } }; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java index 33b7c55a32bfb..4c528659372b5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java @@ -22,28 +22,32 @@ * Base of all snapshots that are taken by {@link StateBackend}s and some other * components in tasks. * - *

Each snapshot is composed of a collection of {@link StateObject}s. The - * {@link StateObject}s in a completed checkpoint may be referenced by other - * completed checkpoints. To avoid the deletion of those objects still in use, - * the handle should register all its objects when the checkpoint completes and - * unregister its objects when the checkpoint is discarded. + *

Each snapshot is composed of a collection of {@link StateObject}s some of + * which may be referenced by other checkpoints. The shared states will be + * registered at the given {@link SharedStateRegistry} when the handle is + * received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} + * and will be discarded when the checkpoint is discarded. + * + *

The {@link SharedStateRegistry} is responsible for the discarding of the + * shared states. The composite state handle should only delete those private + * states in the {@link StateObject#discardState()} method. */ public interface CompositeStateHandle extends StateObject { /** - * This method is called when the checkpoint is added into - * {@link org.apache.flink.runtime.checkpoint.CompletedCheckpointStore}. - * That happens when the pending checkpoint succeeds to complete or the - * completed checkpoint is reloaded in the recovery. In both cases, the - * snapshot handle should register all its objects in the given - * {@link StateRegistry}. + * Register shared states in the given {@link SharedStateRegistry}. This + * method is called when the state handle is received by the + * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}. + * + * @param stateRegistry The registry where shared states are registered. */ - void register(StateRegistry stateRegistry); + void register(SharedStateRegistry stateRegistry); /** - * This method is called when the completed checkpoint is discarded. In such - * cases, the snapshot handle should unregister all its objects. An object - * will be deleted if it is not referenced by any checkpoint. + * Unregister shared states in the given {@link SharedStateRegistry}. This + * method is called when the state handle is discarded. + * + * @param stateRegistry The registry where shared states are registered. */ - void unregister(StateRegistry stateRegistry); + void unregister(SharedStateRegistry stateRegistry); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java deleted file mode 100644 index ad340f5651c33..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -/** - * Base for those state handles that are shared among different checkpoints. - * - * Each shared state handle is identified by an unique key. It will be - * registered at the {@link SharedStateRegistry} once it is received by the - * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}. Each - * registered state handle is unregistered when the checkpoint to which the - * state handle belongs is discarded. - */ -public interface SharedStateHandle extends StateObject { - - /** - * Returns the unique identifier of the state handle - * - * @return the unique identifier of the state handle - */ - String getKey(); -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java index 9f3e76c7e873a..6bef231aa0ba5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -18,9 +18,12 @@ package org.apache.flink.runtime.state; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.java.tuple.Tuple2; +import java.io.Serializable; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,59 +31,162 @@ /** * A {@code SharedStateRegistry} will be deployed in the * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to - * maintain the reference count of the {@link SharedStateHandle}s. + * maintain the reference count of those state objects shared among different + * checkpoints. Each shared state object must be identified by a unique key. */ -public class SharedStateRegistry { +public class SharedStateRegistry implements Serializable { + + private static final long serialVersionUID = -8357254413007773970L; /** All registered state objects */ - private final Map> registeredStates = new HashMap<>(); + private final Map> registeredStates = new HashMap<>(); /** All state objects that are not referenced any more */ - private List discardedStates = new ArrayList<>(); + private transient final List discardedStates = new ArrayList<>(); /** * Register the state in the registry * + * @param key The key of the state to register * @param state The state to register */ - public void register(StateObject state) { - Integer referenceCount = registeredStates.get(state); + public void register(String key, StateObject state) { + Tuple2 stateAndRefCnt = registeredStates.get(key); - if (referenceCount != null) { - registeredStates.put(state, referenceCount + 1); + if (stateAndRefCnt == null) { + registeredStates.put(key, new Tuple2<>(state, 1)); } else { - registeredStates.put(state, 1); + if (!stateAndRefCnt.f0.equals(state)) { + throw new IllegalStateException("Cannot register a key with different states."); + } + + stateAndRefCnt.f1++; } } /** * Decrease the reference count of the state in the registry * - * @param state The state to unregister + * @param key The key of the state to unregister */ - public void unregister(StateObject state) { - Integer referenceCount = registeredStates.get(state); + public void unregister(String key) { + Tuple2 stateAndRefCnt = registeredStates.get(key); - if (referenceCount == null) { + if (stateAndRefCnt == null) { throw new IllegalStateException("Cannot unregister an unexisted state."); } - referenceCount--; + stateAndRefCnt.f1--; + + // Remove the state from the registry when it's not referenced any more. + if (stateAndRefCnt.f1 == 0) { + registeredStates.remove(key); + discardedStates.add(stateAndRefCnt.f0); + } + } + + /** + * Register all the shared states in the given state handles. + * + * @param stateHandles The state handles to register their shared states + */ + public void registerAll(Collection stateHandles) { + synchronized (this) { + if (stateHandles != null) { + for (CompositeStateHandle stateHandle : stateHandles) { + stateHandle.register(this); + } + } + } + } + + /** + * Register all the shared states in the given state handle. + * + * @param stateHandle The state handle to register its shared states + */ + public void registerAll(CompositeStateHandle stateHandle) { + if (stateHandle != null) { + synchronized (this) { + stateHandle.register(this); + } + } + } + + /** + * Unregister all the shared states in the given state handles and return + * those unreferenced states after these shared states are unregistered. + * + * @param stateHandles The state handles to unregister their shared states + * @return The states that are not referenced any more + */ + public List unregisterAll(Collection stateHandles) { + synchronized (this) { + discardedStates.clear(); + + if (stateHandles != null) { + for (CompositeStateHandle stateHandle : stateHandles) { + stateHandle.unregister(this); + } + } - if (referenceCount == 0) { - registeredStates.remove(state); - discardedStates.add(state); + return discardedStates; } + } /** - * Gets and resets the list of discarded state objects + * Unregister all the shared states in the given state handles and return + * those unreferenced states after these shared states are unregistered. * - * @return A list of cached unreferenced state objects + * @param stateHandle The state handle to unregister its shared states + * @return The states that are not referenced any more */ - public List getAndResetDiscardedStates() { - List result = new ArrayList<>(discardedStates); - discardedStates.clear(); + public List unregisterAll(CompositeStateHandle stateHandle) { + if (stateHandle == null) { + return null; + } else { + synchronized (this) { + discardedStates.clear(); + + stateHandle.unregister(this); + + return discardedStates; + } + } + } + + @VisibleForTesting + int getReferenceCount(String key) { + Tuple2 stateAndRefCnt = registeredStates.get(key); + + return stateAndRefCnt == null ? 0 : stateAndRefCnt.f1; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SharedStateRegistry that = (SharedStateRegistry) o; + + return registeredStates.equals(that.registeredStates); + } + + @Override + public int hashCode() { + int result = registeredStates.hashCode(); + result = 31 * result + discardedStates.hashCode(); return result; } + + @Override + public String toString() { + return "SharedStateRegistry{" + "registeredStates=" + registeredStates + + ", discardedStates=" + discardedStates + '}'; + } } diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala index 1e6d8d3b91284..451639daaa42a 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala @@ -1365,27 +1365,12 @@ class JobManager( val savepointPath = savepointSettings.getRestorePath() val allowNonRestored = savepointSettings.allowNonRestoredState() - log.info(s"Starting job from savepoint '$savepointPath'" + - (if (allowNonRestored) " (allowing non restored state)" else "") + ".") - - // load the savepoint as a checkpoint into the system - val savepoint: CompletedCheckpoint = SavepointLoader.loadAndValidateSavepoint( - jobId, - executionGraph.getAllVertices, - savepointPath, - executionGraph.getUserClassLoader, - allowNonRestored) - - executionGraph.getCheckpointCoordinator.getCheckpointStore - .addCheckpoint(savepoint) - - // Reset the checkpoint ID counter - val nextCheckpointId: Long = savepoint.getCheckpointID + 1 - log.info(s"Reset the checkpoint ID to $nextCheckpointId") - executionGraph.getCheckpointCoordinator.getCheckpointIdCounter - .setCount(nextCheckpointId) - - executionGraph.restoreLatestCheckpointedState(true, allowNonRestored) + executionGraph.getCheckpointCoordinator.restoreSavepoint( + savepointPath, + allowNonRestored, + executionGraph.getAllVertices, + executionGraph.getUserClassLoader + ) } catch { case e: Exception => jobInfo.notifyClients( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 5d7e0e07de8de..853ba9cbc18b7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -25,26 +25,24 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.util.TestLogger; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; -import java.util.Collections; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(PendingCheckpoint.class) @@ -107,6 +105,7 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the subtask state has been discarded after we could not complete it. + verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); verify(subtaskState).discardState(); } @@ -118,7 +117,7 @@ public void recover() throws Exception { } @Override - public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { + public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception { throw new Exception("The failing completed checkpoint store failed again... :-("); } @@ -128,7 +127,7 @@ public CompletedCheckpoint getLatestCheckpoint() throws Exception { } @Override - public void shutdown(JobStatus jobStatus) throws Exception { + public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { throw new UnsupportedOperationException("Not implemented."); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 117c70d3139fe..d806ee2ccb44e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -43,11 +43,13 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; +import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; @@ -86,13 +88,17 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; /** * Tests for the checkpoint coordinator. @@ -545,19 +551,25 @@ public void testTriggerAndConfirmSimpleCheckpoint() { } // acknowledge from one of the tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); + SubtaskState subtaskState2 = mock(SubtaskState.class); + AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2); + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); + verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); + verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId)); + SubtaskState subtaskState1 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); + verify(subtaskState1, times(1)).register(any(SharedStateRegistry.class)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -570,6 +582,12 @@ public void testTriggerAndConfirmSimpleCheckpoint() { // the canceler should be removed now assertEquals(0, coord.getNumScheduledTasks()); + // validate that the subtasks states have not unregistered their shared states. + { + verify(subtaskState1, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); + } + // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); @@ -580,7 +598,7 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID()); - assertTrue(success.getTaskStates().isEmpty()); + assertEquals(2, success.getTaskStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint @@ -602,6 +620,12 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertEquals(checkpointIdNew, successNew.getCheckpointID()); assertTrue(successNew.getTaskStates().isEmpty()); + // validate that the subtask states in old savepoint have unregister their shared states + { + verify(subtaskState1, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); + } + // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); @@ -678,8 +702,6 @@ public void testMultipleConcurrentCheckpoints() { verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); - CheckpointMetaData checkpointMetaData1 = new CheckpointMetaData(checkpointId1, 0L); - // acknowledge one of the three tasks coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1)); @@ -699,8 +721,6 @@ public void testMultipleConcurrentCheckpoints() { } long checkpointId2 = pending2.getCheckpointId(); - CheckpointMetaData checkpointMetaData2 = new CheckpointMetaData(checkpointId2, 0L); - // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); @@ -812,10 +832,10 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); - CheckpointMetaData checkpointMetaData1 = new CheckpointMetaData(checkpointId1, 0L); - // acknowledge one of the three tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1)); + SubtaskState subtaskState1_2 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), subtaskState1_2)); + verify(subtaskState1_2, times(1)).register(any(SharedStateRegistry.class)); // start the second checkpoint // trigger the first checkpoint. this should succeed @@ -839,12 +859,22 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. The second checkpoint should then subsume the first checkpoint - CheckpointMetaData checkpointMetaData2= new CheckpointMetaData(checkpointId2, 0L); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2)); + SubtaskState subtaskState2_3 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), subtaskState2_3)); + verify(subtaskState2_3, times(1)).register(any(SharedStateRegistry.class)); + + SubtaskState subtaskState2_1 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), subtaskState2_1)); + verify(subtaskState2_1, times(1)).register(any(SharedStateRegistry.class)); + + SubtaskState subtaskState1_1 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), subtaskState1_1)); + verify(subtaskState1_1, times(1)).register(any(SharedStateRegistry.class)); + + SubtaskState subtaskState2_2 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), subtaskState2_2)); + verify(subtaskState2_2, times(1)).register(any(SharedStateRegistry.class)); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed @@ -855,21 +885,48 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); + // validate that all received subtask states in the first checkpoint have been discarded + verify(subtaskState1_1, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1_2, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1_1, times(1)).discardState(); + verify(subtaskState1_2, times(1)).discardState(); + + // validate that all subtask states in the second checkpoint are not discarded + verify(subtaskState2_1, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_2, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_3, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_1, never()).discardState(); + verify(subtaskState2_2, never()).discardState(); + verify(subtaskState2_3, never()).discardState(); + // validate the committed checkpoints List scs = coord.getSuccessfulCheckpoints(); CompletedCheckpoint success = scs.get(0); assertEquals(checkpointId2, success.getCheckpointID()); assertEquals(timestamp2, success.getTimestamp()); assertEquals(jid, success.getJobId()); - assertTrue(success.getTaskStates().isEmpty()); + assertEquals(3, success.getTaskStates().size()); // the first confirm message should be out verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // send the last remaining ack for the first checkpoint. This should not do anything - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1)); + SubtaskState subtaskState1_3 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); + verify(subtaskState1_3, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState1_3, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); + + // validate that the states in the second checkpoint have been discarded + verify(subtaskState2_1, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_2, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_3, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_1, times(1)).discardState(); + verify(subtaskState2_2, times(1)).discardState(); + verify(subtaskState2_3, times(1)).discardState(); + } catch (Exception e) { e.printStackTrace(); @@ -924,7 +981,9 @@ public void testCheckpointTimeoutIsolated() { PendingCheckpoint checkpoint = coord.getPendingCheckpoints().values().iterator().next(); assertFalse(checkpoint.isDiscarded()); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId())); + SubtaskState subtaskState = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), subtaskState)); + verify(subtaskState, times(1)).register(any(SharedStateRegistry.class)); // wait until the checkpoint must have expired. // we check every 250 msecs conservatively for 5 seconds @@ -941,6 +1000,10 @@ public void testCheckpointTimeoutIsolated() { assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); + // validate that the received states have been discarded + verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState, times(1)).discardState(); + // no confirm message must have been sent verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong()); @@ -993,8 +1056,6 @@ public void testHandleMessagesForNonExistingCheckpoints() { // of the vertices that need to be acknowledged. // non of the messages should throw an exception - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - // wrong job id coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), ackAttemptID1, checkpointId)); @@ -1058,19 +1119,24 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { long checkpointId = pendingCheckpoint.getCheckpointId(); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - SubtaskState triggerSubtaskState = mock(SubtaskState.class); // acknowledge the first trigger vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); + // verify that the subtask state has registered its shared states at the registry + verify(triggerSubtaskState, times(1)).register(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, never()).unregister(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, never()).discardState(); + SubtaskState unknownSubtaskState = mock(SubtaskState.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); // we should discard acknowledge messages from an unknown vertex belonging to our job + verify(unknownSubtaskState, times(1)).register(any(SharedStateRegistry.class)); + verify(unknownSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); verify(unknownSubtaskState, times(1)).discardState(); SubtaskState differentJobSubtaskState = mock(SubtaskState.class); @@ -1079,20 +1145,27 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs + verify(differentJobSubtaskState, never()).register(any(SharedStateRegistry.class)); + verify(differentJobSubtaskState, never()).unregister(any(SharedStateRegistry.class)); verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex + reset(triggerSubtaskState); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state + verify(triggerSubtaskState, never()).register(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, never()).unregister(any(SharedStateRegistry.class)); verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex + reset(triggerSubtaskState); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state + verify(triggerSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); verify(triggerSubtaskState, times(1)).discardState(); SubtaskState ackSubtaskState = mock(SubtaskState.class); @@ -1101,12 +1174,17 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); // check that we also cleaned up this state + verify(ackSubtaskState, times(1)).register(any(SharedStateRegistry.class)); + verify(ackSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); verify(ackSubtaskState, times(1)).discardState(); // receive an acknowledge message from an unknown job + reset(differentJobSubtaskState); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs + verify(differentJobSubtaskState, never()).register(any(SharedStateRegistry.class)); + verify(differentJobSubtaskState, never()).unregister(any(SharedStateRegistry.class)); verify(differentJobSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); @@ -1115,6 +1193,8 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); // we should discard acknowledge messages from an unknown vertex belonging to our job + verify(unknownSubtaskState2, times(1)).register(any(SharedStateRegistry.class)); + verify(unknownSubtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); verify(unknownSubtaskState2, times(1)).discardState(); } @@ -1363,26 +1443,32 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(pending.canBeSubsumed()); - assertTrue(pending instanceof PendingCheckpoint); - - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); // acknowledge from one of the tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); + SubtaskState subtaskState2 = mock(SubtaskState.class); + AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2); + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); + verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId)); + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); + verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId)); + SubtaskState subtaskState1 = mock(SubtaskState.class); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); + verify(subtaskState1, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState1, never()).unregister(any(SharedStateRegistry.class)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1403,7 +1489,7 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(pending.getCheckpointId(), success.getCheckpointID()); - assertTrue(success.getTaskStates().isEmpty()); + assertEquals(2, success.getTaskStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint @@ -1426,6 +1512,15 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertTrue(successNew.getTaskStates().isEmpty()); assertTrue(savepointFuture.isDone()); + // validate that the first savepoint does not discard its private states. + verify(subtaskState1, never()).discardState(); + verify(subtaskState2, never()).discardState(); + + // Savepoints are not supposed to have any shared state. But we still + // call the unregister method in case savepoints register something in the registry. + verify(subtaskState1, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); + // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); @@ -1478,7 +1573,6 @@ public void testSavepointsAreNotSubsumed() throws Exception { // Trigger savepoint and checkpoint Future savepointFuture1 = coord.triggerSavepoint(timestamp, savepointDir); long savepointId1 = counter.getLast(); - CheckpointMetaData checkpointMetaDataS1 = new CheckpointMetaData(savepointId1, 0L); assertEquals(1, coord.getNumberOfPendingCheckpoints()); assertTrue(coord.triggerCheckpoint(timestamp + 1, false)); @@ -1488,8 +1582,6 @@ public void testSavepointsAreNotSubsumed() throws Exception { long checkpointId2 = counter.getLast(); assertEquals(3, coord.getNumberOfPendingCheckpoints()); - CheckpointMetaData checkpointMetaData2 = new CheckpointMetaData(checkpointId2, 0L); - // 2nd checkpoint should subsume the 1st checkpoint, but not the savepoint coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId2)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId2)); @@ -1505,7 +1597,6 @@ public void testSavepointsAreNotSubsumed() throws Exception { Future savepointFuture2 = coord.triggerSavepoint(timestamp + 4, savepointDir); long savepointId2 = counter.getLast(); - CheckpointMetaData checkpointMetaDataS2 = new CheckpointMetaData(savepointId2, 0L); assertEquals(3, coord.getNumberOfPendingCheckpoints()); // 2nd savepoint should subsume the last checkpoint, but not the 1st savepoint @@ -1880,6 +1971,8 @@ public void testRestoreLatestCheckpointedState() throws Exception { ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]); + CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(); + // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( jid, @@ -1892,7 +1985,7 @@ public void testRestoreLatestCheckpointedState() throws Exception { arrayExecutionVertices, arrayExecutionVertices, new StandaloneCheckpointIDCounter(), - new StandaloneCompletedCheckpointStore(1), + store, null, Executors.directExecutor()); @@ -1901,38 +1994,32 @@ public void testRestoreLatestCheckpointedState() throws Exception { assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); - CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); List keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); - KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); + SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); - SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + subtaskState); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID2, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); - KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null); + SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + subtaskState); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -1941,6 +2028,19 @@ public void testRestoreLatestCheckpointedState() throws Exception { assertEquals(1, completedCheckpoints.size()); + // shutdown the store + store.shutdown(JobStatus.SUSPENDED, new SharedStateRegistry()); + + // All shared states should be unregistered once the store is shut down + for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { + for (TaskState taskState : completedCheckpoint.getTaskStates().values()) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + } + } + } + + // restore the store Map tasks = new HashMap<>(); tasks.put(jobVertexID1, jobVertex1); @@ -1948,6 +2048,15 @@ public void testRestoreLatestCheckpointedState() throws Exception { coord.restoreLatestCheckpointedState(tasks, true, false); + // validate that all shared states are registered again after the recovery. + for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { + for (TaskState taskState : completedCheckpoint.getTaskStates().values()) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(2)).register(any(SharedStateRegistry.class)); + } + } + } + // verify the restored state verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1); verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2); @@ -2666,6 +2775,26 @@ private static ExecutionVertex mockExecutionVertex( return vertex; } + static SubtaskState mockSubtaskState( + JobVertexID jobVertexID, + int index, + KeyGroupRange keyGroupRange) throws IOException { + + ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index); + ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false); + KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false); + + SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable()); + + doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState(); + doReturn(partitionableState).when(subtaskState).getManagedOperatorState(); + doReturn(null).when(subtaskState).getRawOperatorState(); + doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState(); + doReturn(null).when(subtaskState).getRawKeyedState(); + + return subtaskState; + } + public static void verifyStateRestore( JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex, List keyGroupPartitions) throws Exception { @@ -3018,7 +3147,6 @@ public void testCheckpointStatsTrackerRestoreCallback() throws Exception { ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID()); StandaloneCompletedCheckpointStore store = new StandaloneCompletedCheckpointStore(1); - store.addCheckpoint(new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.emptyMap())); // set up the coordinator and validate the initial state CheckpointCoordinator coord = new CheckpointCoordinator( @@ -3035,6 +3163,10 @@ public void testCheckpointStatsTrackerRestoreCallback() throws Exception { store, null, Executors.directExecutor()); + + store.addCheckpoint( + new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.emptyMap()), + coord.getSharedStateRegistry()); CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class); coord.setCheckpointStatsTracker(tracker); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 7e0a7c1859cc6..9e372e161fbfd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -255,7 +255,7 @@ public void testNonRestoredState() throws Exception { } CompletedCheckpoint checkpoint = new CompletedCheckpoint(new JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates)); - coord.getCheckpointStore().addCheckpoint(checkpoint); + coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry()); coord.restoreLatestCheckpointedState(tasks, true, false); coord.restoreLatestCheckpointedState(tasks, true, true); @@ -273,7 +273,7 @@ public void testNonRestoredState() throws Exception { checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new HashMap<>(checkpointTaskStates)); - coord.getCheckpointStore().addCheckpoint(checkpoint); + coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry()); // (i) Allow non restored state (should succeed) coord.restoreLatestCheckpointedState(tasks, true, true); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index f6a7fdaed8470..509e2165770ce 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -23,6 +23,8 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.CheckpointMessagesTest; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; @@ -38,6 +40,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Test for basic {@link CompletedCheckpointStore} contract. @@ -66,7 +70,8 @@ public void testExceptionOnNoRetainedCheckpoints() throws Exception { @Test public void testAddAndGetLatestCheckpoint() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); - + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + // Empty state assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints()); assertEquals(0, checkpoints.getAllCheckpoints().size()); @@ -75,11 +80,11 @@ public void testAddAndGetLatestCheckpoint() throws Exception { createCheckpoint(0), createCheckpoint(1) }; // Add and get latest - checkpoints.addCheckpoint(expected[0]); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); verifyCheckpoint(expected[0], checkpoints.getLatestCheckpoint()); - checkpoints.addCheckpoint(expected[1]); + checkpoints.addCheckpoint(expected[1], sharedStateRegistry); assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints()); verifyCheckpoint(expected[1], checkpoints.getLatestCheckpoint()); } @@ -90,7 +95,8 @@ public void testAddAndGetLatestCheckpoint() throws Exception { */ @Test public void testAddCheckpointMoreThanMaxRetained() throws Exception { - CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1); + CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -98,16 +104,25 @@ public void testAddCheckpointMoreThanMaxRetained() throws Exception { }; // Add checkpoints - checkpoints.addCheckpoint(expected[0]); + sharedStateRegistry.registerAll(expected[0].getTaskStates().values()); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); for (int i = 1; i < expected.length; i++) { - checkpoints.addCheckpoint(expected[i]); + Collection taskStates = expected[i - 1].getTaskStates().values(); + + checkpoints.addCheckpoint(expected[i], sharedStateRegistry); // The ZooKeeper implementation discards asynchronously expected[i - 1].awaitDiscard(); assertTrue(expected[i - 1].isDiscarded()); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); + + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).unregister(sharedStateRegistry); + } + } } } @@ -134,6 +149,7 @@ public void testEmptyState() throws Exception { @Test public void testGetAllCheckpoints() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -141,7 +157,7 @@ public void testGetAllCheckpoints() throws Exception { }; for (TestCompletedCheckpoint checkpoint : expected) { - checkpoints.addCheckpoint(checkpoint); + checkpoints.addCheckpoint(checkpoint, sharedStateRegistry); } List actual = checkpoints.getAllCheckpoints(); @@ -159,6 +175,7 @@ public void testGetAllCheckpoints() throws Exception { @Test public void testDiscardAllCheckpoints() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), @@ -166,10 +183,10 @@ public void testDiscardAllCheckpoints() throws Exception { }; for (TestCompletedCheckpoint checkpoint : expected) { - checkpoints.addCheckpoint(checkpoint); + checkpoints.addCheckpoint(checkpoint, sharedStateRegistry); } - checkpoints.shutdown(JobStatus.FINISHED); + checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry); // Empty state assertNull(checkpoints.getLatestCheckpoint()); @@ -205,10 +222,9 @@ protected TestCompletedCheckpoint createCheckpoint(int id, int numberOfStates, C taskGroupStates.put(jvid, taskState); for (int i = 0; i < numberOfStates; i++) { - ChainedStateHandle stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle( - new CheckpointMessagesTest.MyHandle()); + SubtaskState subtaskState = CheckpointCoordinatorTest.mockSubtaskState(jvid, i, new KeyGroupRange(i, i)); - taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null)); + taskState.putState(i, subtaskState); } return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props); @@ -243,8 +259,8 @@ public TestCompletedCheckpoint( } @Override - public boolean subsume(Collection discardedStates) throws Exception { - if (super.subsume(discardedStates)) { + public boolean subsume() throws Exception { + if (super.subsume()) { discard(); return true; } else { @@ -253,8 +269,8 @@ public boolean subsume(Collection discardedStates) throws } @Override - public boolean discard(JobStatus jobStatus, Collection discardedStates) throws Exception { - if (super.discard(jobStatus, discardedStates)) { + public boolean discard(JobStatus jobStatus) throws Exception { + if (super.discard(jobStatus)) { discard(); return true; } else { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index d0c98ac3c2df0..b34e9a6ac26f5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -61,7 +61,7 @@ public void testDiscard() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); - checkpoint.discard(JobStatus.FAILED, checkpoint.getTaskStates().values()); + checkpoint.discard(JobStatus.FAILED); assertEquals(false, file.exists()); } @@ -81,7 +81,7 @@ public void testCleanUpOnSubsume() throws Exception { new JobID(), 0, 0, 1, taskStates, props); // Subsume - checkpoint.subsume(taskStates.values()); + checkpoint.subsume(); verify(state, times(1)).discardState(); } @@ -112,7 +112,7 @@ public void testCleanUpOnShutdown() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); - checkpoint.discard(status, taskStates.values()); + checkpoint.discard(status); verify(state, times(0)).discardState(); assertEquals(true, file.exists()); @@ -121,7 +121,7 @@ public void testCleanUpOnShutdown() throws Exception { checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); - checkpoint.discard(taskStates.values()); + checkpoint.discard(status); verify(state, times(1)).discardState(); } } @@ -146,7 +146,7 @@ public void testCompletedCheckpointStatsCallbacks() throws Exception { CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); completed.setDiscardCallback(callback); - completed.discard(JobStatus.FINISHED, taskStates.values()); + completed.discard(JobStatus.FINISHED); verify(callback, times(1)).notifyDiscardedCheckpoint(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java index 98b4c4deeeba5..f085844b0bb6c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.jobmanager.scheduler.Scheduler; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.util.SerializedValue; import org.junit.AfterClass; @@ -44,6 +45,8 @@ import java.net.URL; import java.util.Collections; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -70,7 +73,7 @@ public void testShutdownCheckpointCoordinator() throws Exception { graph.fail(new Exception("Test Exception")); verify(counter, times(1)).shutdown(JobStatus.FAILED); - verify(store, times(1)).shutdown(JobStatus.FAILED); + verify(store, times(1)).shutdown(eq(JobStatus.FAILED), any(SharedStateRegistry.class)); } /** @@ -86,8 +89,8 @@ public void testSuspendCheckpointCoordinator() throws Exception { graph.suspend(new Exception("Test Exception")); // No shutdown - verify(counter, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED)); - verify(store, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED)); + verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED)); + verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), any(SharedStateRegistry.class)); } private ExecutionGraph createExecutionGraphAndEnableCheckpointing( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index a15684c6c4b7a..00635dc7eb20d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -45,6 +46,7 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -184,52 +186,60 @@ public void testCompletionFuture() throws Exception { @SuppressWarnings("unchecked") public void testAbortDiscardsState() throws Exception { CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); - TaskState state = mock(TaskState.class); QueueExecutor executor = new QueueExecutor(); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + TaskState state = mock(TaskState.class); + doNothing().when(state).register(any(SharedStateRegistry.class)); + doNothing().when(state).unregister(any(SharedStateRegistry.class)); String targetDir = tmpFolder.newFolder().getAbsolutePath(); // Abort declined - PendingCheckpoint pending = createPendingCheckpoint(props, targetDir, executor); + PendingCheckpoint pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); setTaskState(pending, state); pending.abortDeclined(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).unregister(sharedStateRegistry); // Abort error Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor); + pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); setTaskState(pending, state); pending.abortError(new Exception("Expected Test Exception")); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).unregister(sharedStateRegistry); // Abort expired Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor); + pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); setTaskState(pending, state); pending.abortExpired(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).unregister(sharedStateRegistry); // Abort subsumed Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor); + pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); setTaskState(pending, state); pending.abortSubsumed(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); + verify(state, times(1)).unregister(sharedStateRegistry); } /** @@ -337,10 +347,15 @@ public void testSetCanceller() { // ------------------------------------------------------------------------ private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties props, String targetDirectory) { - return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor()); + return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor(), new SharedStateRegistry()); } - private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties props, String targetDirectory, Executor executor) { + private static PendingCheckpoint createPendingCheckpoint( + CheckpointProperties props, + String targetDirectory, + Executor executor, + SharedStateRegistry sharedStateRegistry) { + Map ackTasks = new HashMap<>(ACK_TASKS); return new PendingCheckpoint( new JobID(), @@ -349,7 +364,8 @@ private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties pr ackTasks, props, targetDirectory, - executor); + executor, + sharedStateRegistry); } @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java index cc7b2d0855fb4..6d2f7d0f91b55 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java @@ -19,9 +19,11 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Test; import java.io.IOException; +import java.util.Collections; import java.util.List; import static org.junit.Assert.assertEquals; @@ -49,13 +51,15 @@ protected CompletedCheckpointStore createCompletedCheckpoints( @Test public void testShutdownDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - store.addCheckpoint(checkpoint); - assertEquals(1, store.getNumberOfRetainedCheckpoints()); + sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); - store.shutdown(JobStatus.FINISHED); + store.addCheckpoint(checkpoint, sharedStateRegistry); + assertEquals(1, store.getNumberOfRetainedCheckpoints()); + store.shutdown(JobStatus.FINISHED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); } @@ -67,13 +71,15 @@ public void testShutdownDiscardsCheckpoints() throws Exception { @Test public void testSuspendDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + TestCompletedCheckpoint checkpoint = createCheckpoint(0); - store.addCheckpoint(checkpoint); + sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); - store.shutdown(JobStatus.SUSPENDED); - + store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); } @@ -87,14 +93,16 @@ public void testAddCheckpointWithFailedRemove() throws Exception { final int numCheckpointsToRetain = 1; CompletedCheckpointStore store = createCompletedCheckpoints(numCheckpointsToRetain); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); doThrow(new IOException()).when(checkpointToAdd).subsume(); try { - store.addCheckpoint(checkpointToAdd); + store.addCheckpoint(checkpointToAdd, sharedStateRegistry); // The checkpoint should be in the store if we successfully add it into the store. List addedCheckpoints = store.getAllCheckpoints(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 625999a3343c4..316d6c3fb0f5d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -22,6 +22,7 @@ import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment; import org.junit.AfterClass; @@ -80,15 +81,19 @@ public RetrievableStateHandle store(CompletedCheckpoint sta @Test public void testRecover() throws Exception { CompletedCheckpointStore checkpoints = createCompletedCheckpoints(3); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] { createCheckpoint(0), createCheckpoint(1), createCheckpoint(2) }; // Add multiple checkpoints - checkpoints.addCheckpoint(expected[0]); - checkpoints.addCheckpoint(expected[1]); - checkpoints.addCheckpoint(expected[2]); + sharedStateRegistry.registerAll(expected[0].getTaskStates().values()); + checkpoints.addCheckpoint(expected[0], sharedStateRegistry); + sharedStateRegistry.registerAll(expected[1].getTaskStates().values()); + checkpoints.addCheckpoint(expected[1], sharedStateRegistry); + sharedStateRegistry.registerAll(expected[2].getTaskStates().values()); + checkpoints.addCheckpoint(expected[2], sharedStateRegistry); // All three should be in ZK assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); @@ -106,7 +111,7 @@ public void testRecover() throws Exception { expectedCheckpoints.add(expected[2]); expectedCheckpoints.add(createCheckpoint(3)); - checkpoints.addCheckpoint(expectedCheckpoints.get(2)); + checkpoints.addCheckpoint(expectedCheckpoints.get(2), sharedStateRegistry); List actualCheckpoints = checkpoints.getAllCheckpoints(); @@ -121,14 +126,15 @@ public void testShutdownDiscardsCheckpoints() throws Exception { CuratorFramework client = ZooKeeper.getClient(); CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - store.addCheckpoint(checkpoint); + sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.shutdown(JobStatus.FINISHED); - + store.shutdown(JobStatus.FINISHED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); @@ -146,13 +152,14 @@ public void testSuspendKeepsCheckpoints() throws Exception { CuratorFramework client = ZooKeeper.getClient(); CompletedCheckpointStore store = createCompletedCheckpoints(1); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - store.addCheckpoint(checkpoint); + store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.shutdown(JobStatus.SUSPENDED); + store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java index aa2ec851d53d9..62545fb286d8d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java @@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.state.RetrievableStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.TestLogger; @@ -40,6 +41,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.concurrent.Executor; @@ -222,14 +224,17 @@ public RetrievableStateHandle answer(InvocationOnMock invoc checkpointsPath, stateSotrage, Executors.directExecutor()); + + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); try { - zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd); + zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, sharedStateRegistry); // The checkpoint should be in the store if we successfully add it into the store. List addedCheckpoints = zooKeeperCompletedCheckpointStore.getAllCheckpoints(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index a2cda84428aff..483c3605ef84f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -71,8 +71,7 @@ import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateRegistry; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; @@ -83,6 +82,7 @@ import org.apache.flink.runtime.testingUtils.TestingTaskManager; import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages; import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; @@ -166,7 +166,7 @@ public void testJobRecoveryWhenLosingLeadership() throws Exception { Scheduler scheduler = new Scheduler(TestingUtils.defaultExecutionContext()); MySubmittedJobGraphStore mySubmittedJobGraphStore = new MySubmittedJobGraphStore(); - MyCheckpointStore checkpointStore = new MyCheckpointStore(); + CompletedCheckpointStore checkpointStore = new RecoverableCompletedCheckpointStore(); CheckpointIDCounter checkpointCounter = new StandaloneCheckpointIDCounter(); CheckpointRecoveryFactory checkpointStateFactory = new MyCheckpointRecoveryFactory(checkpointStore, checkpointCounter); TestingLeaderElectionService myLeaderElectionService = new TestingLeaderElectionService(); @@ -440,73 +440,6 @@ public void apply(Object o) throws Exception { } } - /** - * A checkpoint store, which supports shutdown and suspend. You can use this to test HA - * as long as the factory always returns the same store instance. - */ - static class MyCheckpointStore implements CompletedCheckpointStore { - - private final ArrayDeque checkpoints = new ArrayDeque<>(2); - - private final ArrayDeque suspended = new ArrayDeque<>(2); - - private final StateRegistry stateRegistry = new StateRegistry(); - - @Override - public void recover() throws Exception { - checkpoints.addAll(suspended); - suspended.clear(); - } - - @Override - public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - checkpoint.register(stateRegistry); - checkpoints.addLast(checkpoint); - - if (checkpoints.size() > 1) { - CompletedCheckpoint subsumedCheckpoint = checkpoints.removeFirst(); - List discardedStates = subsumedCheckpoint.unregister(stateRegistry); - subsumedCheckpoint.subsume(discardedStates); - } - } - - @Override - public CompletedCheckpoint getLatestCheckpoint() throws Exception { - return checkpoints.isEmpty() ? null : checkpoints.getLast(); - } - - @Override - public void shutdown(JobStatus jobStatus) throws Exception { - if (jobStatus.isGloballyTerminalState()) { - checkpoints.clear(); - suspended.clear(); - } else { - suspended.addAll(checkpoints); - checkpoints.clear(); - } - } - - @Override - public List getAllCheckpoints() throws Exception { - return new ArrayList<>(checkpoints); - } - - @Override - public int getNumberOfRetainedCheckpoints() { - return checkpoints.size(); - } - - @Override - public int getMaxNumberOfRetainedCheckpoints() { - return 1; - } - - @Override - public boolean requiresExternalizedCheckpoints() { - return false; - } - } - static class MyCheckpointRecoveryFactory implements CheckpointRecoveryFactory { private final CompletedCheckpointStore store; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java new file mode 100644 index 0000000000000..13b815248c263 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.runtime.state; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class SharedStateRegistryTest { + + /** + * Validate that all states can be correctly registered at the registry. + */ + @Test + public void testRegistryNormal() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + // register one state + TestState firstState = new TestState("first"); + sharedStateRegistry.register(firstState.getKey(), firstState); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState.getKey())); + + // register another state + TestState secondState = new TestState("second"); + sharedStateRegistry.register(secondState.getKey(), secondState); + assertEquals(1, sharedStateRegistry.getReferenceCount(secondState.getKey())); + + // register the first state again + sharedStateRegistry.register(firstState.getKey(), firstState); + assertEquals(2, sharedStateRegistry.getReferenceCount(firstState.getKey())); + + // unregister the second state + sharedStateRegistry.unregister(secondState.getKey()); + assertEquals(0, sharedStateRegistry.getReferenceCount(secondState.getKey())); + + // unregister the first state + sharedStateRegistry.unregister(firstState.getKey()); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState.getKey())); + } + + /** + * Validate that registering a key with different states will throw exception + */ + @Test(expected = IllegalStateException.class) + public void testRegisterWithInconsistentState() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + // register one state + TestState state = new TestState("state"); + sharedStateRegistry.register("key", state); + assertEquals(1, sharedStateRegistry.getReferenceCount("key")); + + // register the state with another key + TestState anotherState = new TestState("anotherState"); + sharedStateRegistry.register("key", anotherState); + } + + /** + * Validate that unregister an unexisted key will throw exception + */ + @Test(expected = IllegalStateException.class) + public void testUnregisterWithUnexistedKey() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + sharedStateRegistry.unregister("unexistedKey"); + } + + private static class TestState implements StateObject { + private static final long serialVersionUID = 4468635881465159780L; + + private String key; + + TestState(String key) { + this.key = key; + } + + public String getKey() { + return key; + } + + @Override + public void discardState() throws Exception { + // nothing to do + } + + @Override + public long getStateSize() { + return key.length(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TestState testState = (TestState) o; + + return key.equals(testState.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java new file mode 100644 index 0000000000000..e0c915b747244 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.testutils; + +import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; +import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; + +/** + * A checkpoint store, which supports shutdown and suspend. You can use this to test HA + * as long as the factory always returns the same store instance. + */ +public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore { + + private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class); + + private final ArrayDeque checkpoints = new ArrayDeque<>(2); + + private final ArrayDeque suspended = new ArrayDeque<>(2); + + @Override + public void recover() throws Exception { + checkpoints.addAll(suspended); + suspended.clear(); + } + + @Override + public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception { + checkpoints.addLast(checkpoint); + + if (checkpoints.size() > 1) { + CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpointToSubsume.getTaskStates().values()); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Exception e) { + LOG.warn("Could not properly discard unreferenced shared states.", e); + } + + checkpointToSubsume.subsume(); + } + } + + @Override + public CompletedCheckpoint getLatestCheckpoint() throws Exception { + return checkpoints.isEmpty() ? null : checkpoints.getLast(); + } + + @Override + public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { + if (jobStatus.isGloballyTerminalState()) { + checkpoints.clear(); + suspended.clear(); + } else { + suspended.clear(); + + for (CompletedCheckpoint checkpoint : checkpoints) { + List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); + try { + StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); + } catch (Exception e) { + LOG.warn("Could not properly discard unreferenced shared states.", e); + } + + suspended.add(checkpoint); + } + + checkpoints.clear(); + } + } + + @Override + public List getAllCheckpoints() throws Exception { + return new ArrayList<>(checkpoints); + } + + @Override + public int getNumberOfRetainedCheckpoints() { + return checkpoints.size(); + } + + @Override + public int getMaxNumberOfRetainedCheckpoints() { + return 1; + } + + @Override + public boolean requiresExternalizedCheckpoints() { + return false; + } +} From a1db603590e5e4143b2cf37d6f3a67804a9a2028 Mon Sep 17 00:00:00 2001 From: "xiaogang.sxg" Date: Thu, 6 Apr 2017 16:55:50 +0800 Subject: [PATCH 4/4] Introduce SharedStateHandle as the base of the handles to shared states --- .../checkpoint/CheckpointCoordinator.java | 53 ++--- .../checkpoint/CompletedCheckpoint.java | 62 +++++- .../checkpoint/CompletedCheckpointStore.java | 4 +- .../runtime/checkpoint/PendingCheckpoint.java | 23 +- .../StandaloneCompletedCheckpointStore.java | 25 +-- .../runtime/checkpoint/SubtaskState.java | 9 +- .../flink/runtime/checkpoint/TaskState.java | 15 +- .../ZooKeeperCompletedCheckpointStore.java | 47 ++-- .../runtime/state/CompositeStateHandle.java | 29 ++- .../runtime/state/SharedStateHandle.java | 39 ++++ .../runtime/state/SharedStateRegistry.java | 202 ++++++++---------- .../CheckpointCoordinatorFailureTest.java | 4 +- .../checkpoint/CheckpointCoordinatorTest.java | 95 ++++---- .../CompletedCheckpointStoreTest.java | 42 +++- .../checkpoint/CompletedCheckpointTest.java | 24 ++- .../checkpoint/PendingCheckpointTest.java | 34 +-- ...tandaloneCompletedCheckpointStoreTest.java | 13 +- ...oKeeperCompletedCheckpointStoreITCase.java | 25 ++- ...ZooKeeperCompletedCheckpointStoreTest.java | 4 +- .../state/SharedStateRegistryTest.java | 57 ++--- .../RecoverableCompletedCheckpointStore.java | 25 +-- 21 files changed, 461 insertions(+), 370 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index e0ff1980ff61b..4caca844ab531 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -38,8 +38,6 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; @@ -489,8 +487,7 @@ CheckpointTriggerResult triggerCheckpoint( ackTasks, props, targetDirectory, - executor, - sharedStateRegistry); + executor); if (statsTracker != null) { PendingCheckpointStats callback = statsTracker.reportPendingCheckpoint( @@ -718,8 +715,6 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C switch (checkpoint.acknowledgeTask(message.getTaskExecutionId(), message.getSubtaskState(), message.getCheckpointMetrics())) { case SUCCESS: - sharedStateRegistry.registerAll(message.getSubtaskState()); - LOG.debug("Received acknowledge message for checkpoint {} from task {} of job {}.", checkpointId, message.getTaskExecutionId(), message.getJob()); @@ -732,8 +727,6 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C message.getCheckpointId(), message.getTaskExecutionId(), message.getJob()); break; case UNKNOWN: - sharedStateRegistry.registerAll(message.getSubtaskState()); - LOG.warn("Could not acknowledge the checkpoint {} for task {} of job {}, " + "because the task's execution attempt id was unknown. Discarding " + "the state handle to avoid lingering state.", message.getCheckpointId(), @@ -743,8 +736,6 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws C break; case DISCARDED: - sharedStateRegistry.registerAll(message.getSubtaskState()); - LOG.warn("Could not acknowledge the checkpoint {} for task {} of job {}, " + "because the pending checkpoint had been discarded. Discarding the " + "state handle tp avoid lingering state.", @@ -763,8 +754,6 @@ else if (checkpoint != null) { else { boolean wasPendingCheckpoint; - sharedStateRegistry.registerAll(message.getSubtaskState()); - // message is for an unknown checkpoint, or comes too late (checkpoint disposed) if (recentPendingCheckpoints.contains(checkpointId)) { wasPendingCheckpoint = true; @@ -824,10 +813,8 @@ private void completePendingCheckpoint(PendingCheckpoint pendingCheckpoint) thro executor.execute(new Runnable() { @Override public void run() { - sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); - try { - completedCheckpoint.discard(); + completedCheckpoint.discardOnFail(); } catch (Throwable t) { LOG.warn("Could not properly discard completed checkpoint {}.", completedCheckpoint.getCheckpointID(), t); } @@ -966,21 +953,12 @@ public boolean restoreLatestCheckpointedState( } // Recover the checkpoints - completedCheckpointStore.recover(); + completedCheckpointStore.recover(sharedStateRegistry); - // Recover the registry for shared states - CompletedCheckpoint latestCompletedCheckpoint = null; - List completedCheckpoints = completedCheckpointStore.getAllCheckpoints(); - for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { - sharedStateRegistry.registerAll(completedCheckpoint.getTaskStates().values()); + // restore from the latest checkpoint + CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint(); - if (latestCompletedCheckpoint == null || - latestCompletedCheckpoint.getCheckpointID() > completedCheckpoint.getCheckpointID()) { - latestCompletedCheckpoint = completedCheckpoint; - } - } - - if (latestCompletedCheckpoint == null) { + if (latest == null) { if (errorIfNoCheckpoint) { throw new IllegalStateException("No completed checkpoint available"); } else { @@ -988,9 +966,9 @@ public boolean restoreLatestCheckpointedState( } } - LOG.info("Restoring from latest valid checkpoint: {}.", latestCompletedCheckpoint); + LOG.info("Restoring from latest valid checkpoint: {}.", latest); - final Map taskStates = latestCompletedCheckpoint.getTaskStates(); + final Map taskStates = latest.getTaskStates(); StateAssignmentOperation stateAssignmentOperation = new StateAssignmentOperation(LOG, tasks, taskStates, allowNonRestoredState); @@ -1000,10 +978,10 @@ public boolean restoreLatestCheckpointedState( if (statsTracker != null) { long restoreTimestamp = System.currentTimeMillis(); RestoredCheckpointStats restored = new RestoredCheckpointStats( - latestCompletedCheckpoint.getCheckpointID(), - latestCompletedCheckpoint.getProperties(), + latest.getCheckpointID(), + latest.getProperties(), restoreTimestamp, - latestCompletedCheckpoint.getExternalPointer()); + latest.getExternalPointer()); statsTracker.reportRestoredCheckpoint(restored); } @@ -1183,16 +1161,13 @@ private void discardState( executor.execute(new Runnable() { @Override public void run() { - List discardedSharedStates = sharedStateRegistry.unregisterAll(subtaskState); - try { - StateUtil.bestEffortDiscardAllStateObjects(discardedSharedStates); + subtaskState.discardSharedStatesOnFail(); } catch (Throwable t1) { LOG.warn("Could not properly discard shared states of checkpoint {} " + - "belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t1 - ); + "belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t1); } - + try { subtaskState.discardState(); } catch (Throwable t2) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index d1b182fd38967..58e91e1e0c991 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -23,10 +23,12 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -184,22 +186,30 @@ public CheckpointProperties getProperties() { return props; } - public boolean subsume() throws Exception { + public void discardOnFail() throws Exception { + discard(null, true); + } + + public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception { + Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null."); + if (props.discardOnSubsumed()) { - discard(); + discard(sharedStateRegistry, false); return true; } return false; } - public boolean discard(JobStatus jobStatus) throws Exception { + public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { + Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null."); + if (jobStatus == JobStatus.FINISHED && props.discardOnJobFinished() || jobStatus == JobStatus.CANCELED && props.discardOnJobCancelled() || jobStatus == JobStatus.FAILED && props.discardOnJobFailed() || jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) { - discard(); + discard(sharedStateRegistry, false); return true; } else { if (externalPointer != null) { @@ -211,7 +221,10 @@ public boolean discard(JobStatus jobStatus) throws Exception { } } - void discard() throws Exception { + private void discard(SharedStateRegistry sharedStateRegistry, boolean failed) throws Exception { + Preconditions.checkState(failed || (sharedStateRegistry != null), + "The registry must not be null if the complete checkpoint does not fail."); + try { // collect exceptions and continue cleanup Exception exception = null; @@ -225,6 +238,15 @@ void discard() throws Exception { } } + // In the cases where the completed checkpoint fails, the shared + // states have not been registered to the registry. It's the state + // handles' responsibility to discard their shared states. + if (!failed) { + unregisterSharedStates(sharedStateRegistry); + } else { + discardSharedStatesOnFail(); + } + // discard private state objects try { StateUtil.bestEffortDiscardAllStateObjects(taskStates.values()); @@ -287,6 +309,36 @@ void setDiscardCallback(@Nullable CompletedCheckpointStats.DiscardCallback disca this.discardCallback = discardCallback; } + /** + * Register all shared states in the given registry. This is method is called + * when the completed checkpoint has been successfully added into the store. + * + * @param sharedStateRegistry The registry where shared states are registered + */ + public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { + sharedStateRegistry.registerAll(taskStates.values()); + } + + /** + * Unregister all shared states from the given registry. This is method is + * called when the completed checkpoint is subsumed or the job terminates. + * + * @param sharedStateRegistry The registry where shared states are registered + */ + private void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { + sharedStateRegistry.unregisterAll(taskStates.values()); + } + + /** + * Discard all shared states created in the checkpoint. This method is called + * when the completed checkpoint fails to be added into the store. + */ + private void discardSharedStatesOnFail() throws Exception { + for (TaskState taskState : taskStates.values()) { + taskState.discardSharedStatesOnFail(); + } + } + // -------------------------------------------------------------------------------------------- @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java index 27d4d686845e8..0ade25c849328 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java @@ -34,14 +34,14 @@ public interface CompletedCheckpointStore { *

After a call to this method, {@link #getLatestCheckpoint()} returns the latest * available checkpoint. */ - void recover() throws Exception; + void recover(SharedStateRegistry sharedStateRegistry) throws Exception; /** * Adds a {@link CompletedCheckpoint} instance to the list of completed checkpoints. * *

Only a bounded number of checkpoints is kept. When exceeding the maximum number of * retained checkpoints, the oldest one will be discarded via {@link - * CompletedCheckpoint#subsume()} )}. + * CompletedCheckpoint#discardOnSubsume(SharedStateRegistry)} )}. */ void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index 34df45fc3b942..36b5d3f3a65f3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -29,8 +29,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; @@ -45,7 +43,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; @@ -95,9 +92,6 @@ public class PendingCheckpoint { /** The executor for potentially blocking I/O operations, like state disposal */ private final Executor executor; - /** The registry where shared states are registered */ - private final SharedStateRegistry sharedStateRegistry; - private int numAcknowledgedTasks; private boolean discarded; @@ -117,8 +111,7 @@ public PendingCheckpoint( Map verticesToConfirm, CheckpointProperties props, String targetDirectory, - Executor executor, - SharedStateRegistry sharedStateRegistry) { + Executor executor) { // Sanity check if (props.externalizeCheckpoint() && targetDirectory == null) { @@ -135,7 +128,6 @@ public PendingCheckpoint( this.props = checkNotNull(props); this.targetDirectory = targetDirectory; this.executor = Preconditions.checkNotNull(executor); - this.sharedStateRegistry = Preconditions.checkNotNull(sharedStateRegistry); this.taskStates = new HashMap<>(); this.acknowledgedTasks = new HashSet<>(verticesToConfirm.size()); @@ -507,14 +499,17 @@ private void dispose(boolean releaseState) { executor.execute(new Runnable() { @Override public void run() { - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(taskStates.values()); - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Throwable t) { - LOG.warn("Could not properly dispose unreferenced shared states."); + // discard the shared states that are created in the checkpoint + for (TaskState taskState : taskStates.values()) { + try { + taskState.discardSharedStatesOnFail(); + } catch (Throwable t) { + LOG.warn("Could not properly dispose unreferenced shared states."); + } } + // discard the private states try { StateUtil.bestEffortDiscardAllStateObjects(taskStates.values()); } catch (Throwable t) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java index 65faef207b2d1..9f833c31107f6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java @@ -21,8 +21,6 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -59,7 +57,7 @@ public StandaloneCompletedCheckpointStore(int maxNumberOfCheckpointsToRetain) { } @Override - public void recover() throws Exception { + public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { // Nothing to do } @@ -68,18 +66,12 @@ public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sh checkpoints.addLast(checkpoint); + checkpoint.registerSharedStates(sharedStateRegistry); + if (checkpoints.size() > maxNumberOfCheckpointsToRetain) { try { CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); - - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpointToSubsume.getTaskStates().values()); - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Exception e) { - LOG.warn("Could not properly discard unreferenced shared states.", e); - } - - checkpointToSubsume.subsume(); + checkpointToSubsume.discardOnSubsume(sharedStateRegistry); } catch (Exception e) { LOG.warn("Fail to subsume the old checkpoint.", e); } @@ -112,14 +104,7 @@ public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistr LOG.info("Shutting down"); for (CompletedCheckpoint checkpoint : checkpoints) { - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Exception e) { - LOG.warn("Could not properly discard unreferenced shared states.", e); - } - - checkpoint.discard(jobStatus); + checkpoint.discardOnShutdown(jobStatus, sharedStateRegistry); } } finally { checkpoints.clear(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index 487b191fed9ff..9f11656524e46 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -144,12 +144,17 @@ public void discardState() { } @Override - public void register(SharedStateRegistry sharedStateRegistry) { + public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { // No shared states } @Override - public void unregister(SharedStateRegistry sharedStateRegistry) { + public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { + // No shared states + } + + @Override + public void discardSharedStatesOnFail() { // No shared states } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index dd86c87602b47..19fe9624746be 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -130,16 +130,23 @@ public void discardState() throws Exception { } @Override - public void register(SharedStateRegistry sharedStateRegistry) { + public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { for (SubtaskState subtaskState : subtaskStates.values()) { - subtaskState.register(sharedStateRegistry); + subtaskState.registerSharedStates(sharedStateRegistry); } } @Override - public void unregister(SharedStateRegistry sharedStateRegistry) { + public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { for (SubtaskState subtaskState : subtaskStates.values()) { - subtaskState.unregister(sharedStateRegistry); + subtaskState.unregisterSharedStates(sharedStateRegistry); + } + } + + @Override + public void discardSharedStatesOnFail() { + for (SubtaskState subtaskState : subtaskStates.values()) { + subtaskState.discardSharedStatesOnFail(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java index 33da79baf6901..07546ea4227ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java @@ -28,8 +28,6 @@ import org.apache.flink.runtime.jobmanager.HighAvailabilityMode; import org.apache.flink.runtime.state.RetrievableStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper; import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore; import org.apache.flink.util.FlinkException; @@ -143,7 +141,7 @@ public boolean requiresExternalizedCheckpoints() { * that the history of checkpoints is consistent. */ @Override - public void recover() throws Exception { + public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { LOG.info("Recovering checkpoints from ZooKeeper."); // Clear local handles in order to prevent duplicates on @@ -167,8 +165,24 @@ public void recover() throws Exception { LOG.info("Found {} checkpoints in ZooKeeper.", numberOfInitialCheckpoints); - for (Tuple2, String> checkpoint : initialCheckpoints) { - checkpointStateHandles.add(checkpoint); + for (Tuple2, String> checkpointStateHandle : initialCheckpoints) { + + CompletedCheckpoint completedCheckpoint = null; + + try { + completedCheckpoint = retrieveCompletedCheckpoint(checkpointStateHandle); + } catch (Exception e) { + LOG.warn("Could not retrieve checkpoint. Removing it from the completed " + + "checkpoint store.", e); + + // remove the checkpoint with broken state handle + removeBrokenStateHandle(checkpointStateHandle); + } + + if (completedCheckpoint != null) { + completedCheckpoint.registerSharedStates(sharedStateRegistry); + checkpointStateHandles.add(checkpointStateHandle); + } } } @@ -189,6 +203,9 @@ public void addCheckpoint(final CompletedCheckpoint checkpoint, final SharedStat checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path)); + // Register all shared states in the checkpoint + checkpoint.registerSharedStates(sharedStateRegistry); + // Everything worked, let's remove a previous checkpoint if necessary. while (checkpointStateHandles.size() > maxNumberOfCheckpointsToRetain) { try { @@ -302,15 +319,7 @@ public Void call() throws Exception { CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath); if (completedCheckpoint != null) { - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); - - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Throwable t) { - LOG.warn("Could not properly discard unreferenced shared states.", t); - } - - completedCheckpoint.subsume(); + completedCheckpoint.discardOnSubsume(sharedStateRegistry); } return null; @@ -331,15 +340,7 @@ public Void call() throws Exception { CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath); if (completedCheckpoint != null) { - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(completedCheckpoint.getTaskStates().values()); - - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Throwable t) { - LOG.warn("Could not properly discard unreferenced shared states.", t); - } - - completedCheckpoint.discard(jobStatus); + completedCheckpoint.discardOnShutdown(jobStatus, sharedStateRegistry); } return null; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java index 4c528659372b5..2ea5bc9080e12 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java @@ -24,30 +24,37 @@ * *

Each snapshot is composed of a collection of {@link StateObject}s some of * which may be referenced by other checkpoints. The shared states will be - * registered at the given {@link SharedStateRegistry} when the handle is + * registered at the given {@link SharedStateRegistry} when the handle is * received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} * and will be discarded when the checkpoint is discarded. * - *

The {@link SharedStateRegistry} is responsible for the discarding of the + *

The {@link SharedStateRegistry} is responsible for the discarding of the * shared states. The composite state handle should only delete those private * states in the {@link StateObject#discardState()} method. */ public interface CompositeStateHandle extends StateObject { /** - * Register shared states in the given {@link SharedStateRegistry}. This - * method is called when the state handle is received by the - * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}. - * + * Register both created and referenced shared states in the given + * {@link SharedStateRegistry}. This method is called when the checkpoint + * successfully completes or is recovered from failures. + * * @param stateRegistry The registry where shared states are registered. */ - void register(SharedStateRegistry stateRegistry); + void registerSharedStates(SharedStateRegistry stateRegistry); /** - * Unregister shared states in the given {@link SharedStateRegistry}. This - * method is called when the state handle is discarded. - * + * Unregister both created and referenced shared states in the given + * {@link SharedStateRegistry}. This method is called when the checkpoint is + * subsumed or the job is shut down. + * * @param stateRegistry The registry where shared states are registered. */ - void unregister(SharedStateRegistry stateRegistry); + void unregisterSharedStates(SharedStateRegistry stateRegistry); + + /** + * Discard all shared states created in this checkpoint. This method is + * called when the checkpoint fails to complete. + */ + void discardSharedStatesOnFail() throws Exception; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java new file mode 100644 index 0000000000000..f856052904952 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * A handle to those states that are referenced by different checkpoints. + * + *

Each shared state handle is identified by a unique key. Two shared states + * are considered equal if their keys are identical. + * + *

All shared states are registered at the {@link SharedStateRegistry} once + * they are received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} + * and will be unregistered when the checkpoints are discarded. A shared state + * will be discarded once it is not referenced by any checkpoint. A shared state + * should not be referenced any more if it has been discarded. + */ +public interface SharedStateHandle extends StateObject { + + /** + * Return the identifier of the shared state. + */ + String getKey(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java index 6bef231aa0ba5..b5048d0d96208 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -19,174 +19,160 @@ package org.apache.flink.runtime.state; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.Serializable; -import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Map; /** * A {@code SharedStateRegistry} will be deployed in the * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to - * maintain the reference count of those state objects shared among different - * checkpoints. Each shared state object must be identified by a unique key. + * maintain the reference count of {@link SharedStateHandle}s which are shared + * among different checkpoints. */ public class SharedStateRegistry implements Serializable { + private static Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class); + private static final long serialVersionUID = -8357254413007773970L; /** All registered state objects */ - private final Map> registeredStates = new HashMap<>(); - - /** All state objects that are not referenced any more */ - private transient final List discardedStates = new ArrayList<>(); + private final Map registeredStates = new HashMap<>(); /** * Register the state in the registry * - * @param key The key of the state to register * @param state The state to register + * @param isNew True if the shared state is newly created */ - public void register(String key, StateObject state) { - Tuple2 stateAndRefCnt = registeredStates.get(key); - - if (stateAndRefCnt == null) { - registeredStates.put(key, new Tuple2<>(state, 1)); - } else { - if (!stateAndRefCnt.f0.equals(state)) { - throw new IllegalStateException("Cannot register a key with different states."); - } + public void register(SharedStateHandle state, boolean isNew) { + if (state == null) { + return; + } + + synchronized (registeredStates) { + SharedStateEntry entry = registeredStates.get(state.getKey()); + + if (isNew) { + Preconditions.checkState(entry == null, + "The state cannot be created more than once."); - stateAndRefCnt.f1++; + registeredStates.put(state.getKey(), new SharedStateEntry(state)); + } else { + Preconditions.checkState(entry != null, + "The state cannot be referenced if it has not been created yet."); + + entry.increaseReferenceCount(); + } } } /** - * Decrease the reference count of the state in the registry + * Unregister the state in the registry * - * @param key The key of the state to unregister + * @param state The state to unregister */ - public void unregister(String key) { - Tuple2 stateAndRefCnt = registeredStates.get(key); - - if (stateAndRefCnt == null) { - throw new IllegalStateException("Cannot unregister an unexisted state."); + public void unregister(SharedStateHandle state) { + if (state == null) { + return; } - stateAndRefCnt.f1--; + synchronized (registeredStates) { + SharedStateEntry entry = registeredStates.get(state.getKey()); - // Remove the state from the registry when it's not referenced any more. - if (stateAndRefCnt.f1 == 0) { - registeredStates.remove(key); - discardedStates.add(stateAndRefCnt.f0); - } - } + if (entry == null) { + throw new IllegalStateException("Cannot unregister an unexisted state."); + } - /** - * Register all the shared states in the given state handles. - * - * @param stateHandles The state handles to register their shared states - */ - public void registerAll(Collection stateHandles) { - synchronized (this) { - if (stateHandles != null) { - for (CompositeStateHandle stateHandle : stateHandles) { - stateHandle.register(this); + entry.decreaseReferenceCount(); + + // Remove the state from the registry when it's not referenced any more. + if (entry.getReferenceCount() == 0) { + registeredStates.remove(state.getKey()); + + try { + entry.getState().discardState(); + } catch (Exception e) { + LOG.warn("Cannot properly discard the state " + entry.getState() + ".", e); } } } } - - /** - * Register all the shared states in the given state handle. - * - * @param stateHandle The state handle to register its shared states - */ - public void registerAll(CompositeStateHandle stateHandle) { - if (stateHandle != null) { - synchronized (this) { - stateHandle.register(this); - } - } - } /** - * Unregister all the shared states in the given state handles and return - * those unreferenced states after these shared states are unregistered. - * - * @param stateHandles The state handles to unregister their shared states - * @return The states that are not referenced any more + * Register given shared states in the registry. + * + * @param stateHandles The shared states to register. */ - public List unregisterAll(Collection stateHandles) { - synchronized (this) { - discardedStates.clear(); + public void registerAll(Collection stateHandles) { + if (stateHandles == null) { + return; + } - if (stateHandles != null) { - for (CompositeStateHandle stateHandle : stateHandles) { - stateHandle.unregister(this); - } + synchronized (registeredStates) { + for (CompositeStateHandle stateHandle : stateHandles) { + stateHandle.registerSharedStates(this); } - - return discardedStates; } - } + + /** - * Unregister all the shared states in the given state handles and return - * those unreferenced states after these shared states are unregistered. + * Unregister all the shared states referenced by the given. * - * @param stateHandle The state handle to unregister its shared states - * @return The states that are not referenced any more + * @param stateHandles The shared states to unregister. */ - public List unregisterAll(CompositeStateHandle stateHandle) { - if (stateHandle == null) { - return null; - } else { - synchronized (this) { - discardedStates.clear(); - - stateHandle.unregister(this); + public void unregisterAll(Collection stateHandles) { + if (stateHandles == null) { + return; + } - return discardedStates; + synchronized (registeredStates) { + for (CompositeStateHandle stateHandle : stateHandles) { + stateHandle.unregisterSharedStates(this); } } } - @VisibleForTesting - int getReferenceCount(String key) { - Tuple2 stateAndRefCnt = registeredStates.get(key); + private static class SharedStateEntry { + /** The shared object */ + private final SharedStateHandle state; - return stateAndRefCnt == null ? 0 : stateAndRefCnt.f1; - } + /** The reference count of the object */ + private int referenceCount; - @Override - public boolean equals(Object o) { - if (this == o) { - return true; + SharedStateEntry(SharedStateHandle value) { + this.state = value; + this.referenceCount = 1; } - if (o == null || getClass() != o.getClass()) { - return false; + + SharedStateHandle getState() { + return state; } - SharedStateRegistry that = (SharedStateRegistry) o; + int getReferenceCount() { + return referenceCount; + } - return registeredStates.equals(that.registeredStates); - } + void increaseReferenceCount() { + ++referenceCount; + } - @Override - public int hashCode() { - int result = registeredStates.hashCode(); - result = 31 * result + discardedStates.hashCode(); - return result; + void decreaseReferenceCount() { + --referenceCount; + } } - @Override - public String toString() { - return "SharedStateRegistry{" + "registeredStates=" + registeredStates + - ", discardedStates=" + discardedStates + '}'; + + @VisibleForTesting + public int getReferenceCount(SharedStateHandle state) { + SharedStateEntry entry = registeredStates.get(state.getKey()); + + return entry == null ? 0 : entry.getReferenceCount(); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 853ba9cbc18b7..632f2c0678a3d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -105,14 +105,14 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the subtask state has been discarded after we could not complete it. - verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState, times(1)).discardSharedStatesOnFail(); verify(subtaskState).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { @Override - public void recover() throws Exception { + public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { throw new UnsupportedOperationException("Not implemented."); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index d806ee2ccb44e..fabf3fc33ff8c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -558,18 +558,17 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. SubtaskState subtaskState1 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); - verify(subtaskState1, times(1)).register(any(SharedStateRegistry.class)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -582,10 +581,10 @@ public void testTriggerAndConfirmSimpleCheckpoint() { // the canceler should be removed now assertEquals(0, coord.getNumScheduledTasks()); - // validate that the subtasks states have not unregistered their shared states. + // validate that the subtasks states have registered their shared states. { - verify(subtaskState1, never()).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class)); } // validate that the relevant tasks got a confirmation message @@ -622,8 +621,8 @@ public void testTriggerAndConfirmSimpleCheckpoint() { // validate that the subtask states in old savepoint have unregister their shared states { - verify(subtaskState1, times(1)).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); } // validate that the relevant tasks got a confirmation message @@ -835,7 +834,6 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { // acknowledge one of the three tasks SubtaskState subtaskState1_2 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), subtaskState1_2)); - verify(subtaskState1_2, times(1)).register(any(SharedStateRegistry.class)); // start the second checkpoint // trigger the first checkpoint. this should succeed @@ -862,19 +860,15 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { SubtaskState subtaskState2_3 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), subtaskState2_3)); - verify(subtaskState2_3, times(1)).register(any(SharedStateRegistry.class)); SubtaskState subtaskState2_1 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), subtaskState2_1)); - verify(subtaskState2_1, times(1)).register(any(SharedStateRegistry.class)); SubtaskState subtaskState1_1 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), subtaskState1_1)); - verify(subtaskState1_1, times(1)).register(any(SharedStateRegistry.class)); SubtaskState subtaskState2_2 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), subtaskState2_2)); - verify(subtaskState2_2, times(1)).register(any(SharedStateRegistry.class)); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed @@ -886,15 +880,15 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that all received subtask states in the first checkpoint have been discarded - verify(subtaskState1_1, times(1)).unregister(any(SharedStateRegistry.class)); - verify(subtaskState1_2, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1_1, times(1)).discardSharedStatesOnFail(); + verify(subtaskState1_2, times(1)).discardSharedStatesOnFail(); verify(subtaskState1_1, times(1)).discardState(); verify(subtaskState1_2, times(1)).discardState(); // validate that all subtask states in the second checkpoint are not discarded - verify(subtaskState2_1, never()).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2_2, never()).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2_3, never()).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_1, never()).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2_2, never()).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2_3, never()).unregisterSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2_1, never()).discardState(); verify(subtaskState2_2, never()).discardState(); verify(subtaskState2_3, never()).discardState(); @@ -913,16 +907,15 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { // send the last remaining ack for the first checkpoint. This should not do anything SubtaskState subtaskState1_3 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); - verify(subtaskState1_3, times(1)).register(any(SharedStateRegistry.class)); - verify(subtaskState1_3, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState1_3, times(1)).discardSharedStatesOnFail(); verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); // validate that the states in the second checkpoint have been discarded - verify(subtaskState2_1, times(1)).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2_2, times(1)).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2_3, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState2_1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2_2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2_3, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2_1, times(1)).discardState(); verify(subtaskState2_2, times(1)).discardState(); verify(subtaskState2_3, times(1)).discardState(); @@ -983,7 +976,6 @@ public void testCheckpointTimeoutIsolated() { SubtaskState subtaskState = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), subtaskState)); - verify(subtaskState, times(1)).register(any(SharedStateRegistry.class)); // wait until the checkpoint must have expired. // we check every 250 msecs conservatively for 5 seconds @@ -1001,7 +993,7 @@ public void testCheckpointTimeoutIsolated() { assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that the received states have been discarded - verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState, times(1)).discardSharedStatesOnFail(); verify(subtaskState, times(1)).discardState(); // no confirm message must have been sent @@ -1125,8 +1117,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // verify that the subtask state has registered its shared states at the registry - verify(triggerSubtaskState, times(1)).register(any(SharedStateRegistry.class)); - verify(triggerSubtaskState, never()).unregister(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, never()).discardSharedStatesOnFail(); verify(triggerSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState = mock(SubtaskState.class); @@ -1135,8 +1126,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); // we should discard acknowledge messages from an unknown vertex belonging to our job - verify(unknownSubtaskState, times(1)).register(any(SharedStateRegistry.class)); - verify(unknownSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(unknownSubtaskState, times(1)).discardSharedStatesOnFail(); verify(unknownSubtaskState, times(1)).discardState(); SubtaskState differentJobSubtaskState = mock(SubtaskState.class); @@ -1145,8 +1135,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs - verify(differentJobSubtaskState, never()).register(any(SharedStateRegistry.class)); - verify(differentJobSubtaskState, never()).unregister(any(SharedStateRegistry.class)); + verify(differentJobSubtaskState, never()).discardSharedStatesOnFail(); verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex @@ -1154,8 +1143,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state - verify(triggerSubtaskState, never()).register(any(SharedStateRegistry.class)); - verify(triggerSubtaskState, never()).unregister(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, never()).discardSharedStatesOnFail(); verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex @@ -1165,7 +1153,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state - verify(triggerSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(triggerSubtaskState, times(1)).discardSharedStatesOnFail(); verify(triggerSubtaskState, times(1)).discardState(); SubtaskState ackSubtaskState = mock(SubtaskState.class); @@ -1174,8 +1162,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); // check that we also cleaned up this state - verify(ackSubtaskState, times(1)).register(any(SharedStateRegistry.class)); - verify(ackSubtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(ackSubtaskState, times(1)).discardSharedStatesOnFail(); verify(ackSubtaskState, times(1)).discardState(); // receive an acknowledge message from an unknown job @@ -1183,8 +1170,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); // we should not interfere with different jobs - verify(differentJobSubtaskState, never()).register(any(SharedStateRegistry.class)); - verify(differentJobSubtaskState, never()).unregister(any(SharedStateRegistry.class)); + verify(differentJobSubtaskState, never()).discardSharedStatesOnFail(); verify(differentJobSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); @@ -1193,8 +1179,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); // we should discard acknowledge messages from an unknown vertex belonging to our job - verify(unknownSubtaskState2, times(1)).register(any(SharedStateRegistry.class)); - verify(unknownSubtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); + verify(unknownSubtaskState2, times(1)).discardSharedStatesOnFail(); verify(unknownSubtaskState2, times(1)).discardState(); } @@ -1453,22 +1438,16 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); - verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); - verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(savepointFuture.isDone()); - verify(subtaskState2, times(1)).register(any(SharedStateRegistry.class)); - verify(subtaskState2, never()).unregister(any(SharedStateRegistry.class)); // acknowledge the other task. SubtaskState subtaskState1 = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); - verify(subtaskState1, times(1)).register(any(SharedStateRegistry.class)); - verify(subtaskState1, never()).unregister(any(SharedStateRegistry.class)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1485,6 +1464,12 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { verify(vertex2.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId), eq(timestamp)); } + // validate that the shared states are registered + { + verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class)); + } + CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0); assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); @@ -1516,10 +1501,9 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { verify(subtaskState1, never()).discardState(); verify(subtaskState2, never()).discardState(); - // Savepoints are not supposed to have any shared state. But we still - // call the unregister method in case savepoints register something in the registry. - verify(subtaskState1, times(1)).unregister(any(SharedStateRegistry.class)); - verify(subtaskState2, times(1)).unregister(any(SharedStateRegistry.class)); + // Savepoints are not supposed to have any shared state. + verify(subtaskState1, never()).unregisterSharedStates(any(SharedStateRegistry.class)); + verify(subtaskState2, never()).unregisterSharedStates(any(SharedStateRegistry.class)); // validate that the relevant tasks got a confirmation message { @@ -2029,13 +2013,14 @@ public void testRestoreLatestCheckpointedState() throws Exception { assertEquals(1, completedCheckpoints.size()); // shutdown the store - store.shutdown(JobStatus.SUSPENDED, new SharedStateRegistry()); + SharedStateRegistry sharedStateRegistry = coord.getSharedStateRegistry(); + store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); // All shared states should be unregistered once the store is shut down for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { for (TaskState taskState : completedCheckpoint.getTaskStates().values()) { for (SubtaskState subtaskState : taskState.getStates()) { - verify(subtaskState, times(1)).unregister(any(SharedStateRegistry.class)); + verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry); } } } @@ -2052,7 +2037,7 @@ public void testRestoreLatestCheckpointedState() throws Exception { for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { for (TaskState taskState : completedCheckpoint.getTaskStates().values()) { for (SubtaskState subtaskState : taskState.getStates()) { - verify(subtaskState, times(2)).register(any(SharedStateRegistry.class)); + verify(subtaskState, times(2)).registerSharedStates(sharedStateRegistry); } } } @@ -3163,9 +3148,9 @@ public void testCheckpointStatsTrackerRestoreCallback() throws Exception { store, null, Executors.directExecutor()); - + store.addCheckpoint( - new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.emptyMap()), + new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.emptyMap()), coord.getSharedStateRegistry()); CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 509e2165770ce..aa1726bde5f45 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -21,14 +21,11 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.messages.CheckpointMessagesTest; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.StateObject; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; import org.junit.Test; +import org.mockito.Mockito; import java.io.IOException; import java.util.Collection; @@ -40,6 +37,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -104,7 +102,6 @@ public void testAddCheckpointMoreThanMaxRetained() throws Exception { }; // Add checkpoints - sharedStateRegistry.registerAll(expected[0].getTaskStates().values()); checkpoints.addCheckpoint(expected[0], sharedStateRegistry); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); @@ -120,7 +117,7 @@ public void testAddCheckpointMoreThanMaxRetained() throws Exception { for (TaskState taskState : taskStates) { for (SubtaskState subtaskState : taskState.getStates()) { - verify(subtaskState, times(1)).unregister(sharedStateRegistry); + verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry); } } } @@ -230,6 +227,31 @@ protected TestCompletedCheckpoint createCheckpoint(int id, int numberOfStates, C return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props); } + protected void resetCheckpoint(Collection taskStates) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + Mockito.reset(subtaskState); + } + } + } + + protected void verifyCheckpointRegistered(Collection taskStates, SharedStateRegistry sharedStateRegistry) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).registerSharedStates(eq(sharedStateRegistry)); + } + } + } + + protected void verifyCheckpointDiscarded(Collection taskStates) { + for (TaskState taskState : taskStates) { + for (SubtaskState subtaskState : taskState.getStates()) { + verify(subtaskState, times(1)).discardSharedStatesOnFail(); + verify(subtaskState, times(1)).discardState(); + } + } + } + private void verifyCheckpoint(CompletedCheckpoint expected, CompletedCheckpoint actual) { assertEquals(expected, actual); } @@ -259,8 +281,8 @@ public TestCompletedCheckpoint( } @Override - public boolean subsume() throws Exception { - if (super.subsume()) { + public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception { + if (super.discardOnSubsume(sharedStateRegistry)) { discard(); return true; } else { @@ -269,8 +291,8 @@ public boolean subsume() throws Exception { } @Override - public boolean discard(JobStatus jobStatus) throws Exception { - if (super.discard(jobStatus)) { + public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception { + if (super.discardOnShutdown(jobStatus, sharedStateRegistry)) { discard(); return true; } else { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index b34e9a6ac26f5..0b759d4056273 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -23,6 +23,8 @@ import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; import org.junit.Test; @@ -30,10 +32,12 @@ import org.mockito.Mockito; import java.io.File; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -61,7 +65,7 @@ public void testDiscard() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); - checkpoint.discard(JobStatus.FAILED); + checkpoint.discardOnShutdown(JobStatus.FAILED, new SharedStateRegistry()); assertEquals(false, file.exists()); } @@ -80,10 +84,15 @@ public void testCleanUpOnSubsume() throws Exception { CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, taskStates, props); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + checkpoint.registerSharedStates(sharedStateRegistry); + verify(state, times(1)).registerSharedStates(sharedStateRegistry); + // Subsume - checkpoint.subsume(); + checkpoint.discardOnSubsume(sharedStateRegistry); verify(state, times(1)).discardState(); + verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } /** @@ -112,17 +121,22 @@ public void testCleanUpOnShutdown() throws Exception { new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); - checkpoint.discard(status); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + checkpoint.registerSharedStates(sharedStateRegistry); + + checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(0)).discardState(); assertEquals(true, file.exists()); + verify(state, times(0)).unregisterSharedStates(sharedStateRegistry); // Discard props = new CheckpointProperties(false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); - checkpoint.discard(status); + checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); + verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } } @@ -146,7 +160,7 @@ public void testCompletedCheckpointStatsCallbacks() throws Exception { CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); completed.setDiscardCallback(callback); - completed.discard(JobStatus.FINISHED); + completed.discardOnShutdown(JobStatus.FINISHED, new SharedStateRegistry()); verify(callback, times(1)).notifyDiscardedCheckpoint(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index 00635dc7eb20d..d77fac1bed9dd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -24,16 +24,19 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.SharedStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.mockito.Mock; import org.mockito.Mockito; import java.io.File; import java.lang.reflect.Field; import java.util.ArrayDeque; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Queue; @@ -47,7 +50,9 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.powermock.api.mockito.PowerMockito.when; @@ -187,59 +192,58 @@ public void testCompletionFuture() throws Exception { public void testAbortDiscardsState() throws Exception { CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); QueueExecutor executor = new QueueExecutor(); - SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TaskState state = mock(TaskState.class); - doNothing().when(state).register(any(SharedStateRegistry.class)); - doNothing().when(state).unregister(any(SharedStateRegistry.class)); + doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class)); + doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class)); String targetDir = tmpFolder.newFolder().getAbsolutePath(); // Abort declined - PendingCheckpoint pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); + PendingCheckpoint pending = createPendingCheckpoint(props, targetDir, executor); setTaskState(pending, state); pending.abortDeclined(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); - verify(state, times(1)).unregister(sharedStateRegistry); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort error Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); + pending = createPendingCheckpoint(props, targetDir, executor); setTaskState(pending, state); pending.abortError(new Exception("Expected Test Exception")); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); - verify(state, times(1)).unregister(sharedStateRegistry); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort expired Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); + pending = createPendingCheckpoint(props, targetDir, executor); setTaskState(pending, state); pending.abortExpired(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); - verify(state, times(1)).unregister(sharedStateRegistry); + verify(state, times(1)).discardSharedStatesOnFail(); // Abort subsumed Mockito.reset(state); - pending = createPendingCheckpoint(props, targetDir, executor, sharedStateRegistry); + pending = createPendingCheckpoint(props, targetDir, executor); setTaskState(pending, state); pending.abortSubsumed(); // execute asynchronous discard operation executor.runQueuedCommands(); verify(state, times(1)).discardState(); - verify(state, times(1)).unregister(sharedStateRegistry); + verify(state, times(1)).discardSharedStatesOnFail(); } /** @@ -347,14 +351,13 @@ public void testSetCanceller() { // ------------------------------------------------------------------------ private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties props, String targetDirectory) { - return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor(), new SharedStateRegistry()); + return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor()); } private static PendingCheckpoint createPendingCheckpoint( CheckpointProperties props, String targetDirectory, - Executor executor, - SharedStateRegistry sharedStateRegistry) { + Executor executor) { Map ackTasks = new HashMap<>(ACK_TASKS); return new PendingCheckpoint( @@ -364,8 +367,7 @@ private static PendingCheckpoint createPendingCheckpoint( ackTasks, props, targetDirectory, - executor, - sharedStateRegistry); + executor); } @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java index 6d2f7d0f91b55..7a8589719819e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java @@ -23,6 +23,7 @@ import org.junit.Test; import java.io.IOException; +import java.util.Collection; import java.util.Collections; import java.util.List; @@ -53,15 +54,16 @@ public void testShutdownDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - - sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); + Collection taskStates = checkpoint.getTaskStates().values(); store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); + verifyCheckpointRegistered(taskStates, sharedStateRegistry); store.shutdown(JobStatus.FINISHED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); + verifyCheckpointDiscarded(taskStates); } /** @@ -72,16 +74,17 @@ public void testShutdownDiscardsCheckpoints() throws Exception { public void testSuspendDiscardsCheckpoints() throws Exception { CompletedCheckpointStore store = createCompletedCheckpoints(1); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); - TestCompletedCheckpoint checkpoint = createCheckpoint(0); + Collection taskStates = checkpoint.getTaskStates().values(); - sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); + verifyCheckpointRegistered(taskStates, sharedStateRegistry); store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); + verifyCheckpointDiscarded(taskStates); } /** @@ -99,7 +102,7 @@ public void testAddCheckpointWithFailedRemove() throws Exception { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); - doThrow(new IOException()).when(checkpointToAdd).subsume(); + doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry); try { store.addCheckpoint(checkpointToAdd, sharedStateRegistry); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 316d6c3fb0f5d..607e773cb19c6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -88,19 +88,25 @@ public void testRecover() throws Exception { }; // Add multiple checkpoints - sharedStateRegistry.registerAll(expected[0].getTaskStates().values()); checkpoints.addCheckpoint(expected[0], sharedStateRegistry); - sharedStateRegistry.registerAll(expected[1].getTaskStates().values()); checkpoints.addCheckpoint(expected[1], sharedStateRegistry); - sharedStateRegistry.registerAll(expected[2].getTaskStates().values()); checkpoints.addCheckpoint(expected[2], sharedStateRegistry); + verifyCheckpointRegistered(expected[0].getTaskStates().values(), sharedStateRegistry); + verifyCheckpointRegistered(expected[1].getTaskStates().values(), sharedStateRegistry); + verifyCheckpointRegistered(expected[2].getTaskStates().values(), sharedStateRegistry); + // All three should be in ZK assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); + resetCheckpoint(expected[0].getTaskStates().values()); + resetCheckpoint(expected[1].getTaskStates().values()); + resetCheckpoint(expected[2].getTaskStates().values()); + // Recover - checkpoints.recover(); + SharedStateRegistry newSharedStateRegistry = new SharedStateRegistry(); + checkpoints.recover(newSharedStateRegistry); assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); @@ -111,11 +117,15 @@ public void testRecover() throws Exception { expectedCheckpoints.add(expected[2]); expectedCheckpoints.add(createCheckpoint(3)); - checkpoints.addCheckpoint(expectedCheckpoints.get(2), sharedStateRegistry); + checkpoints.addCheckpoint(expectedCheckpoints.get(2), newSharedStateRegistry); List actualCheckpoints = checkpoints.getAllCheckpoints(); assertEquals(expectedCheckpoints, actualCheckpoints); + + for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) { + verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), newSharedStateRegistry); + } } /** @@ -129,7 +139,6 @@ public void testShutdownDiscardsCheckpoints() throws Exception { SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - sharedStateRegistry.registerAll(checkpoint.getTaskStates().values()); store.addCheckpoint(checkpoint, sharedStateRegistry); assertEquals(1, store.getNumberOfRetainedCheckpoints()); assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); @@ -138,7 +147,7 @@ public void testShutdownDiscardsCheckpoints() throws Exception { assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); - store.recover(); + store.recover(sharedStateRegistry); assertEquals(0, store.getNumberOfRetainedCheckpoints()); } @@ -165,7 +174,7 @@ public void testSuspendKeepsCheckpoints() throws Exception { assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID())); // Recover again - store.recover(); + store.recover(sharedStateRegistry); CompletedCheckpoint recovered = store.getLatestCheckpoint(); assertEquals(checkpoint, recovered); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java index 62545fb286d8d..1f5731d319bb1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java @@ -160,7 +160,9 @@ public Void answer(InvocationOnMock invocation) throws Throwable { stateSotrage, Executors.directExecutor()); - zooKeeperCompletedCheckpointStore.recover(); + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); + + zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry); CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java index 13b815248c263..cb14ff04d10ec 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java @@ -33,43 +33,51 @@ public void testRegistryNormal() { SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); // register one state - TestState firstState = new TestState("first"); - sharedStateRegistry.register(firstState.getKey(), firstState); - assertEquals(1, sharedStateRegistry.getReferenceCount(firstState.getKey())); + TestSharedState firstState = new TestSharedState("first"); + sharedStateRegistry.register(firstState, true); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState)); // register another state - TestState secondState = new TestState("second"); - sharedStateRegistry.register(secondState.getKey(), secondState); - assertEquals(1, sharedStateRegistry.getReferenceCount(secondState.getKey())); + TestSharedState secondState = new TestSharedState("second"); + sharedStateRegistry.register(secondState, true); + assertEquals(1, sharedStateRegistry.getReferenceCount(secondState)); // register the first state again - sharedStateRegistry.register(firstState.getKey(), firstState); - assertEquals(2, sharedStateRegistry.getReferenceCount(firstState.getKey())); + sharedStateRegistry.register(firstState, false); + assertEquals(2, sharedStateRegistry.getReferenceCount(firstState)); // unregister the second state - sharedStateRegistry.unregister(secondState.getKey()); - assertEquals(0, sharedStateRegistry.getReferenceCount(secondState.getKey())); + sharedStateRegistry.unregister(secondState); + assertEquals(0, sharedStateRegistry.getReferenceCount(secondState)); // unregister the first state - sharedStateRegistry.unregister(firstState.getKey()); - assertEquals(1, sharedStateRegistry.getReferenceCount(firstState.getKey())); + sharedStateRegistry.unregister(firstState); + assertEquals(1, sharedStateRegistry.getReferenceCount(firstState)); } /** - * Validate that registering a key with different states will throw exception + * Validate that registering a handle referencing uncreated state will throw exception */ @Test(expected = IllegalStateException.class) - public void testRegisterWithInconsistentState() { + public void testRegisterWithUncreatedReference() { SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); // register one state - TestState state = new TestState("state"); - sharedStateRegistry.register("key", state); - assertEquals(1, sharedStateRegistry.getReferenceCount("key")); + TestSharedState state = new TestSharedState("state"); + sharedStateRegistry.register(state, false); + } + + /** + * Validate that registering duplicate creation of the same state will throw exception + */ + @Test(expected = IllegalStateException.class) + public void testRegisterWithDuplicateState() { + SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); - // register the state with another key - TestState anotherState = new TestState("anotherState"); - sharedStateRegistry.register("key", anotherState); + // register one state + TestSharedState state = new TestSharedState("state"); + sharedStateRegistry.register(state, true); + sharedStateRegistry.register(state, true); } /** @@ -79,18 +87,19 @@ public void testRegisterWithInconsistentState() { public void testUnregisterWithUnexistedKey() { SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); - sharedStateRegistry.unregister("unexistedKey"); + sharedStateRegistry.unregister(new TestSharedState("unexisted")); } - private static class TestState implements StateObject { + private static class TestSharedState implements SharedStateHandle { private static final long serialVersionUID = 4468635881465159780L; private String key; - TestState(String key) { + TestSharedState(String key) { this.key = key; } + @Override public String getKey() { return key; } @@ -114,7 +123,7 @@ public boolean equals(Object o) { return false; } - TestState testState = (TestState) o; + TestSharedState testState = (TestSharedState) o; return key.equals(testState.key); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java index e0c915b747244..75b0f6fb7343d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ -44,25 +44,24 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS private final ArrayDeque suspended = new ArrayDeque<>(2); @Override - public void recover() throws Exception { + public void recover(SharedStateRegistry sharedStateRegistry) throws Exception { checkpoints.addAll(suspended); suspended.clear(); + + for (CompletedCheckpoint checkpoint : checkpoints) { + checkpoint.registerSharedStates(sharedStateRegistry); + } } @Override public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception { checkpoints.addLast(checkpoint); + checkpoint.registerSharedStates(sharedStateRegistry); + if (checkpoints.size() > 1) { CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpointToSubsume.getTaskStates().values()); - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Exception e) { - LOG.warn("Could not properly discard unreferenced shared states.", e); - } - - checkpointToSubsume.subsume(); + checkpointToSubsume.discardOnSubsume(sharedStateRegistry); } } @@ -80,13 +79,7 @@ public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistr suspended.clear(); for (CompletedCheckpoint checkpoint : checkpoints) { - List unreferencedSharedStates = sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); - try { - StateUtil.bestEffortDiscardAllStateObjects(unreferencedSharedStates); - } catch (Exception e) { - LOG.warn("Could not properly discard unreferenced shared states.", e); - } - + sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); suspended.add(checkpoint); }