From 2ed4f6b28c2fda674f1319f2a3678b2a231988ac Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Mon, 26 Jun 2017 18:07:59 +0200 Subject: [PATCH 1/5] [FLINK-7213] Introduce state management by OperatorID in TaskManager --- .../state/RocksDBAsyncSnapshotTest.java | 21 +- .../checkpoint/CheckpointCoordinator.java | 2 +- .../CheckpointCoordinatorGateway.java | 2 +- .../runtime/checkpoint/OperatorState.java | 4 +- .../checkpoint/OperatorSubtaskState.java | 191 ++++++--- .../runtime/checkpoint/PendingCheckpoint.java | 54 +-- .../checkpoint/StateAssignmentOperation.java | 52 ++- .../runtime/checkpoint/TaskStateSnapshot.java | 136 +++++++ .../flink/runtime/execution/Environment.java | 9 +- .../flink/runtime/jobmaster/JobMaster.java | 5 +- .../checkpoint/AcknowledgeCheckpoint.java | 8 +- .../rpc/RpcCheckpointResponder.java | 4 +- .../ActorGatewayCheckpointResponder.java | 4 +- .../taskmanager/CheckpointResponder.java | 4 +- .../taskmanager/RuntimeEnvironment.java | 4 +- .../CheckpointCoordinatorFailureTest.java | 50 +-- .../checkpoint/CheckpointCoordinatorTest.java | 372 +++++++++--------- .../CheckpointStateRestoreTest.java | 21 +- .../checkpoint/PendingCheckpointTest.java | 2 +- .../jobmanager/JobManagerHARecoveryTest.java | 60 +-- .../messages/CheckpointMessagesTest.java | 23 +- .../operators/testutils/DummyEnvironment.java | 4 +- .../operators/testutils/MockEnvironment.java | 6 +- .../runtime/util/JvmExitOnFatalErrorTest.java | 7 +- .../streaming/api/graph/StreamConfig.java | 13 +- .../api/graph/StreamingJobGraphGenerator.java | 11 +- .../api/operators/AbstractStreamOperator.java | 7 +- .../api/operators/StreamOperator.java | 2 + .../streaming/runtime/tasks/StreamTask.java | 149 +++---- ...bstractUdfStreamOperatorLifecycleTest.java | 5 +- .../async/AsyncWaitOperatorTest.java | 15 +- .../operators/StreamTaskTimerTest.java | 2 + .../TestProcessingTimeServiceTest.java | 2 + .../tasks/BlockingCheckpointsTest.java | 2 + .../tasks/InterruptSensitiveRestoreTest.java | 3 + .../runtime/tasks/OneInputStreamTaskTest.java | 29 +- .../SourceExternalCheckpointTriggerTest.java | 2 + .../runtime/tasks/SourceStreamTaskTest.java | 3 + .../runtime/tasks/StreamMockEnvironment.java | 4 +- .../StreamTaskCancellationBarrierTest.java | 3 + .../tasks/StreamTaskTerminationTest.java | 2 + .../runtime/tasks/StreamTaskTest.java | 62 ++- .../runtime/tasks/StreamTaskTestHarness.java | 2 + .../runtime/tasks/TwoInputStreamTaskTest.java | 5 + .../AbstractStreamOperatorTestHarness.java | 2 + ...OperatorIDMappedStateToChainConverter.java | 91 +++++ 46 files changed, 929 insertions(+), 532 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java index d2edf0efaf768..3d56b5e26ac16 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java @@ -32,8 +32,10 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; @@ -74,6 +76,7 @@ import java.lang.reflect.Field; import java.net.URI; import java.util.Arrays; +import java.util.Map; import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; @@ -81,7 +84,7 @@ import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -137,6 +140,7 @@ public String getKey(String value) throws Exception { streamConfig.setStateBackend(backend); streamConfig.setStreamOperator(new AsyncCheckpointOperator()); + streamConfig.setOperatorID(new OperatorID()); final OneShotLatch delayCheckpointLatch = new OneShotLatch(); final OneShotLatch ensureCheckpointLatch = new OneShotLatch(); @@ -152,7 +156,7 @@ public String getKey(String value) throws Exception { public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { super.acknowledgeCheckpoint(checkpointId, checkpointMetrics); @@ -164,8 +168,16 @@ public void acknowledgeCheckpoint( throw new RuntimeException(e); } + boolean hasKeyedManagedKeyedState = false; + for (Map.Entry entry : checkpointStateHandles.getSubtaskStateMappings()) { + OperatorSubtaskState state = entry.getValue(); + if (state != null) { + hasKeyedManagedKeyedState |= state.getManagedKeyedState() != null; + } + } + // should be one k/v state - assertNotNull(checkpointStateHandles.getManagedKeyedState()); + assertTrue(hasKeyedManagedKeyedState); // we now know that the checkpoint went through ensureCheckpointLatch.trigger(); @@ -241,6 +253,7 @@ public String getKey(String value) throws Exception { streamConfig.setStateBackend(backend); streamConfig.setStreamOperator(new AsyncCheckpointOperator()); + streamConfig.setOperatorID(new OperatorID()); StreamMockEnvironment mockEnv = new StreamMockEnvironment( testHarness.jobConfig, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 3e3615822c126..2c9c902d7cdcd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -1258,7 +1258,7 @@ private void discardSubtaskState( final JobID jobId, final ExecutionAttemptID executionAttemptID, final long checkpointId, - final SubtaskState subtaskState) { + final TaskStateSnapshot subtaskState) { if (subtaskState != null) { executor.execute(new Runnable() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java index 43d66ee719604..22244f6cb8d51 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java @@ -29,7 +29,7 @@ void acknowledgeCheckpoint( final ExecutionAttemptID executionAttemptID, final long checkpointId, final CheckpointMetrics checkpointMetrics, - final SubtaskState subtaskState); + final TaskStateSnapshot subtaskState); void declineCheckpoint( JobID jobID, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java index b15302835bb38..145ff6a978931 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java @@ -30,8 +30,8 @@ import java.util.Objects; /** - * Simple container class which contains the raw/managed/legacy operator state and key-group state handles for the sub - * tasks of an operator. + * Simple container class which contains the raw/managed/legacy operator state and key-group state handles from all sub + * tasks of an operator and therefore represents the complete state of a logical operator. */ public class OperatorState implements CompositeStateHandle { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index e2ae632a26b1b..afe2c2fe87769 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -25,13 +25,23 @@ import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; /** - * Container for the state of one parallel subtask of an operator. This is part of the {@link OperatorState}. + * This class encapsulates the state for one parallel instance of an operator. The complete state of a (logical) + * operator (e.g. a flatmap operator) consists of the union of all {@link OperatorSubtaskState}s from all + * parallel tasks that physically execute parallelized, physical instances of the operator. + *

+ * The full state of the logical operator is represented by {@link OperatorState} which consists of + * {@link OperatorSubtaskState}s. */ public class OperatorSubtaskState implements CompositeStateHandle { @@ -51,22 +61,22 @@ public class OperatorSubtaskState implements CompositeStateHandle { /** * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}. */ - private final OperatorStateHandle managedOperatorState; + private final Collection managedOperatorState; /** * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}. */ - private final OperatorStateHandle rawOperatorState; + private final Collection rawOperatorState; /** * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}. */ - private final KeyedStateHandle managedKeyedState; + private final Collection managedKeyedState; /** * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}. */ - private final KeyedStateHandle rawKeyedState; + private final Collection rawKeyedState; /** * The state size. This is also part of the deserialized state handle. @@ -75,32 +85,65 @@ public class OperatorSubtaskState implements CompositeStateHandle { */ private final long stateSize; + public OperatorSubtaskState() { + this.legacyOperatorState = null; + this.managedOperatorState = Collections.emptyList(); + this.rawOperatorState = Collections.emptyList(); + this.managedKeyedState = Collections.emptyList(); + this.rawKeyedState = Collections.emptyList(); + this.stateSize = 0L; + } + public OperatorSubtaskState( StreamStateHandle legacyOperatorState, - OperatorStateHandle managedOperatorState, - OperatorStateHandle rawOperatorState, - KeyedStateHandle managedKeyedState, - KeyedStateHandle rawKeyedState) { + Collection managedOperatorState, + Collection rawOperatorState, + Collection managedKeyedState, + Collection rawKeyedState) { this.legacyOperatorState = legacyOperatorState; - this.managedOperatorState = managedOperatorState; - this.rawOperatorState = rawOperatorState; - this.managedKeyedState = managedKeyedState; - this.rawKeyedState = rawKeyedState; + this.managedOperatorState = Preconditions.checkNotNull(managedOperatorState); + this.rawOperatorState = Preconditions.checkNotNull(rawOperatorState); + this.managedKeyedState = Preconditions.checkNotNull(managedKeyedState); + this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState); try { - long calculateStateSize = getSizeNullSafe(legacyOperatorState); - calculateStateSize += getSizeNullSafe(managedOperatorState); - calculateStateSize += getSizeNullSafe(rawOperatorState); - calculateStateSize += getSizeNullSafe(managedKeyedState); - calculateStateSize += getSizeNullSafe(rawKeyedState); + long calculateStateSize = sumAllSizes(legacyOperatorState); + calculateStateSize += sumAllSizes(managedOperatorState); + calculateStateSize += sumAllSizes(rawOperatorState); + calculateStateSize += sumAllSizes(managedKeyedState); + calculateStateSize += sumAllSizes(rawKeyedState); stateSize = calculateStateSize; } catch (Exception e) { throw new RuntimeException("Failed to get state size.", e); } } - private static long getSizeNullSafe(StateObject stateObject) throws Exception { + public OperatorSubtaskState( + StreamStateHandle legacyOperatorState, + OperatorStateHandle managedOperatorState, + OperatorStateHandle rawOperatorState, + KeyedStateHandle managedKeyedState, + KeyedStateHandle rawKeyedState) { + + this(legacyOperatorState, + Collections.singletonList(managedOperatorState), + Collections.singletonList(rawOperatorState), + Collections.singletonList(managedKeyedState), + Collections.singletonList(rawKeyedState)); + } + + private static long sumAllSizes(Collection stateObject) throws Exception { + + long size = 0L; + for (StateObject object : stateObject) { + size += sumAllSizes(object); + } + + return size; + } + + private static long sumAllSizes(StateObject stateObject) throws Exception { return stateObject != null ? stateObject.getStateSize() : 0L; } @@ -115,32 +158,49 @@ public StreamStateHandle getLegacyOperatorState() { return legacyOperatorState; } - public OperatorStateHandle getManagedOperatorState() { + /** + * Returns a handle to the managed operator state. + */ + public Collection getManagedOperatorState() { return managedOperatorState; } - public OperatorStateHandle getRawOperatorState() { + /** + * Returns a handle to the raw operator state. + */ + public Collection getRawOperatorState() { return rawOperatorState; } - public KeyedStateHandle getManagedKeyedState() { + /** + * Returns a handle to the managed keyed state. + */ + public Collection getManagedKeyedState() { return managedKeyedState; } - public KeyedStateHandle getRawKeyedState() { + /** + * Returns a handle to the raw keyed state. + */ + public Collection getRawKeyedState() { return rawKeyedState; } @Override public void discardState() { try { - StateUtil.bestEffortDiscardAllStateObjects( - Arrays.asList( - legacyOperatorState, - managedOperatorState, - rawOperatorState, - managedKeyedState, - rawKeyedState)); + List toDispose = + new ArrayList<>(1 + + managedOperatorState.size() + + rawOperatorState.size() + + managedKeyedState.size() + + rawKeyedState.size()); + toDispose.add(legacyOperatorState); + toDispose.addAll(managedOperatorState); + toDispose.addAll(rawOperatorState); + toDispose.addAll(managedKeyedState); + toDispose.addAll(rawKeyedState); + StateUtil.bestEffortDiscardAllStateObjects(toDispose); } catch (Exception e) { LOG.warn("Error while discarding operator states.", e); } @@ -148,12 +208,17 @@ public void discardState() { @Override public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { - if (managedKeyedState != null) { - managedKeyedState.registerSharedStates(sharedStateRegistry); - } + registerSharedState(sharedStateRegistry, managedKeyedState); + registerSharedState(sharedStateRegistry, rawKeyedState); + } - if (rawKeyedState != null) { - rawKeyedState.registerSharedStates(sharedStateRegistry); + private static void registerSharedState( + SharedStateRegistry sharedStateRegistry, + Iterable stateHandles) { + for (KeyedStateHandle stateHandle : stateHandles) { + if (stateHandle != null) { + stateHandle.registerSharedStates(sharedStateRegistry); + } } } @@ -164,6 +229,7 @@ public long getStateSize() { // -------------------------------------------------------------------------------------------- + @Override public boolean equals(Object o) { if (this == o) { @@ -175,44 +241,32 @@ public boolean equals(Object o) { OperatorSubtaskState that = (OperatorSubtaskState) o; - if (stateSize != that.stateSize) { + if (getStateSize() != that.getStateSize()) { return false; } - - if (legacyOperatorState != null ? - !legacyOperatorState.equals(that.legacyOperatorState) - : that.legacyOperatorState != null) { + if (getLegacyOperatorState() != null ? !getLegacyOperatorState().equals(that.getLegacyOperatorState()) : that.getLegacyOperatorState() != null) { return false; } - if (managedOperatorState != null ? - !managedOperatorState.equals(that.managedOperatorState) - : that.managedOperatorState != null) { + if (!getManagedOperatorState().equals(that.getManagedOperatorState())) { return false; } - if (rawOperatorState != null ? - !rawOperatorState.equals(that.rawOperatorState) - : that.rawOperatorState != null) { + if (!getRawOperatorState().equals(that.getRawOperatorState())) { return false; } - if (managedKeyedState != null ? - !managedKeyedState.equals(that.managedKeyedState) - : that.managedKeyedState != null) { + if (!getManagedKeyedState().equals(that.getManagedKeyedState())) { return false; } - return rawKeyedState != null ? - rawKeyedState.equals(that.rawKeyedState) - : that.rawKeyedState == null; - + return getRawKeyedState().equals(that.getRawKeyedState()); } @Override public int hashCode() { - int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; - result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); - result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); - result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); - result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0); - result = 31 * result + (int) (stateSize ^ (stateSize >>> 32)); + int result = getLegacyOperatorState() != null ? getLegacyOperatorState().hashCode() : 0; + result = 31 * result + getManagedOperatorState().hashCode(); + result = 31 * result + getRawOperatorState().hashCode(); + result = 31 * result + getManagedKeyedState().hashCode(); + result = 31 * result + getRawKeyedState().hashCode(); + result = 31 * result + (int) (getStateSize() ^ (getStateSize() >>> 32)); return result; } @@ -227,4 +281,21 @@ public String toString() { ", stateSize=" + stateSize + '}'; } + + public boolean hasState() { + return legacyOperatorState != null + || hasState(managedOperatorState) + || hasState(rawOperatorState) + || hasState(managedKeyedState) + || hasState(rawKeyedState); + } + + private boolean hasState(Iterable states) { + for (StateObject state : states) { + if (state != null) { + return true; + } + } + return false; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index 0633fec92fb1b..9c1e648fecd19 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -27,19 +27,18 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -354,13 +353,13 @@ private CompletedCheckpoint finalizeInternal( * Acknowledges the task with the given execution attempt id and the given subtask state. * * @param executionAttemptId of the acknowledged task - * @param subtaskState of the acknowledged task + * @param operatorSubtaskStates of the acknowledged task * @param metrics Checkpoint metrics for the stats * @return TaskAcknowledgeResult of the operation */ public TaskAcknowledgeResult acknowledgeTask( ExecutionAttemptID executionAttemptId, - SubtaskState subtaskState, + TaskStateSnapshot operatorSubtaskStates, CheckpointMetrics metrics) { synchronized (lock) { @@ -384,21 +383,19 @@ public TaskAcknowledgeResult acknowledgeTask( int subtaskIndex = vertex.getParallelSubtaskIndex(); long ackTimestamp = System.currentTimeMillis(); - long stateSize = 0; - if (subtaskState != null) { - stateSize = subtaskState.getStateSize(); - - @SuppressWarnings("deprecation") - ChainedStateHandle nonPartitionedState = - subtaskState.getLegacyOperatorState(); - ChainedStateHandle partitioneableState = - subtaskState.getManagedOperatorState(); - ChainedStateHandle rawOperatorState = - subtaskState.getRawOperatorState(); - - // break task state apart into separate operator states - for (int x = 0; x < operatorIDs.size(); x++) { - OperatorID operatorID = operatorIDs.get(x); + long stateSize = 0L; + + if (operatorSubtaskStates != null) { + for (OperatorID operatorID : operatorIDs) { + + OperatorSubtaskState operatorSubtaskState = + operatorSubtaskStates.getSubtaskStateByOperatorID(operatorID); + + // if no real operatorSubtaskState was reported, we insert an empty state + if (operatorSubtaskState == null) { + operatorSubtaskState = new OperatorSubtaskState(); + } + OperatorState operatorState = operatorStates.get(operatorID); if (operatorState == null) { @@ -409,23 +406,8 @@ public TaskAcknowledgeResult acknowledgeTask( operatorStates.put(operatorID, operatorState); } - KeyedStateHandle managedKeyedState = null; - KeyedStateHandle rawKeyedState = null; - - // only the head operator retains the keyed state - if (x == operatorIDs.size() - 1) { - managedKeyedState = subtaskState.getManagedKeyedState(); - rawKeyedState = subtaskState.getRawKeyedState(); - } - - OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( - nonPartitionedState != null ? nonPartitionedState.get(x) : null, - partitioneableState != null ? partitioneableState.get(x) : null, - rawOperatorState != null ? rawOperatorState.get(x) : null, - managedKeyedState, - rawKeyedState); - operatorState.putState(subtaskIndex, operatorSubtaskState); + stateSize += operatorSubtaskState.getStateSize(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 5712ea1d43827..250c63e248546 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -32,6 +32,7 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -265,12 +266,8 @@ private Tuple2, Collection> reAss if (newParallelism == oldParallelism) { if (operatorState.getState(subTaskIndex) != null) { - KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); - KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); - subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList( - oldSubManagedKeyedState) : null; - subRawKeyedState = oldSubRawKeyedState != null ? Collections.singletonList( - oldSubRawKeyedState) : null; + subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); + subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); } else { subManagedKeyedState = null; subRawKeyedState = null; @@ -355,14 +352,14 @@ private void collectPartionableStates( if (managedOperatorState == null) { managedOperatorState = new ArrayList<>(); } - managedOperatorState.add(operatorSubtaskState.getManagedOperatorState()); + managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState()); } if (operatorSubtaskState.getRawOperatorState() != null) { if (rawOperatorState == null) { rawOperatorState = new ArrayList<>(); } - rawOperatorState.add(operatorSubtaskState.getRawOperatorState()); + rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState()); } } @@ -389,13 +386,23 @@ public static List getManagedKeyedStateHandles( for (int i = 0; i < operatorState.getParallelism(); i++) { if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange); - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); + Collection keyedStateHandles = operatorState.getState(i).getManagedKeyedState(); + for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { + + //TODO deduplicate code!!!!!!! + if(keyedStateHandle != null) { + + KeyedStateHandle intersectedKeyedStateHandle = + keyedStateHandle.getIntersection(subtaskKeyGroupRange); + + if (intersectedKeyedStateHandle != null) { + if (subtaskKeyedStateHandles == null) { + subtaskKeyedStateHandles = new ArrayList<>(); + } + subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); + } } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); } } } @@ -419,13 +426,22 @@ public static List getRawKeyedStateHandles( for (int i = 0; i < operatorState.getParallelism(); i++) { if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange); - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); + Collection rawKeyedState = operatorState.getState(i).getRawKeyedState(); + + for (KeyedStateHandle keyedStateHandle : rawKeyedState) { + + if (keyedStateHandle != null) { + + KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange); + + if (intersectedKeyedStateHandle != null) { + if (subtaskKeyedStateHandles == null) { + subtaskKeyedStateHandles = new ArrayList<>(); + } + subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); + } } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java new file mode 100644 index 0000000000000..ee59ed97a8691 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.CompositeStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; +import org.apache.flink.runtime.state.StateUtil; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * This class encapsulates state handles to the snapshots of all operator instances executed within one task. A task + * can run multiple operator instances as a result of operator chaining, and all operator instances from the chain can + * register their state under their operator id. Each operator instance is a physical execution responsible for + * processing a partition of the data that goes through a logical operator. This partitioning happens to parallelize + * execution of logical operators, e.g. distributing a map function. + *

+ * One instance of this class contains the information that one task will send to acknowledge a checkpoint request by t + * he checkpoint coordinator. Tasks run operator instances in parallel, so the union of all + * {@link TaskStateSnapshot} that are collected by the checkpoint coordinator from all tasks represent the whole + * state of a job at the time of the checkpoint. + */ +public class TaskStateSnapshot implements CompositeStateHandle { + + private static final long serialVersionUID = 1L; + + /** Mapping from an operator id to the state of one subtask of this operator */ + private final Map subtaskStatesByOperatorID; + + public TaskStateSnapshot() { + this(10); + } + + public TaskStateSnapshot(int size) { + this(new HashMap(size)); + } + + public TaskStateSnapshot(Map subtaskStatesByOperatorID) { + this.subtaskStatesByOperatorID = Preconditions.checkNotNull(subtaskStatesByOperatorID); + } + + /** + * Returns the subtask state for the given operator id (or null if not contained). + */ + public OperatorSubtaskState getSubtaskStateByOperatorID(OperatorID operatorID) { + return subtaskStatesByOperatorID.get(operatorID); + } + + /** + * Maps the given operator id to the given subtask state. Returns the subtask state of a previous mapping, if such + * a mapping existed or null otherwise. + */ + public OperatorSubtaskState putSubtaskStateByOperatorID(OperatorID operatorID, OperatorSubtaskState state) { + return subtaskStatesByOperatorID.put(operatorID, Preconditions.checkNotNull(state)); + } + + /** + * Returns the set of all mappings from operator id to the corresponding subtask state. + */ + public Set> getSubtaskStateMappings() { + return subtaskStatesByOperatorID.entrySet(); + } + + @Override + public void discardState() throws Exception { + StateUtil.bestEffortDiscardAllStateObjects(subtaskStatesByOperatorID.values()); + } + + @Override + public long getStateSize() { + long size = 0L; + + for (OperatorSubtaskState subtaskState : subtaskStatesByOperatorID.values()) { + if (subtaskState != null) { + size += subtaskState.getStateSize(); + } + } + + return size; + } + + @Override + public void registerSharedStates(SharedStateRegistry stateRegistry) { + for (OperatorSubtaskState operatorSubtaskState : subtaskStatesByOperatorID.values()) { + if (operatorSubtaskState != null) { + operatorSubtaskState.registerSharedStates(stateRegistry); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TaskStateSnapshot that = (TaskStateSnapshot) o; + + return subtaskStatesByOperatorID.equals(that.subtaskStatesByOperatorID); + } + + @Override + public int hashCode() { + return subtaskStatesByOperatorID.hashCode(); + } + + @Override + public String toString() { + return "TaskOperatorSubtaskStates{" + + "subtaskStatesByOperatorID=" + subtaskStatesByOperatorID + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 9e9f7c4c719c4..203ee8547cf42 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.execution; -import java.util.Map; -import java.util.concurrent.Future; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; @@ -28,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; @@ -41,6 +39,9 @@ import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; +import java.util.Map; +import java.util.concurrent.Future; + /** * The Environment gives the code executed in a task access to the task's properties * (such as name, parallelism), the configurations, the data stream readers and writers, @@ -175,7 +176,7 @@ public interface Environment { * @param checkpointMetrics metrics for this checkpoint * @param subtaskState All state handles for the checkpointed state */ - void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState); + void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState); /** * Declines a checkpoint. This tells the checkpoint coordinator that this task will diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java index 3a55f2eb2e260..8428cf8a1d989 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java @@ -32,7 +32,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.client.JobExecutionException; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; @@ -100,6 +100,7 @@ import org.slf4j.Logger; import javax.annotation.Nullable; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -565,7 +566,7 @@ public void acknowledgeCheckpoint( final ExecutionAttemptID executionAttemptID, final long checkpointId, final CheckpointMetrics checkpointMetrics, - final SubtaskState checkpointState) throws CheckpointException { + final TaskStateSnapshot checkpointState) throws CheckpointException { final CheckpointCoordinator checkpointCoordinator = executionGraph.getCheckpointCoordinator(); final AcknowledgeCheckpoint ackMessage = diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java index 9721c2cd6f306..65e3019951fcc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; /** @@ -36,7 +36,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements private static final long serialVersionUID = -7606214777192401493L; - private final SubtaskState subtaskState; + private final TaskStateSnapshot subtaskState; private final CheckpointMetrics checkpointMetrics; @@ -47,7 +47,7 @@ public AcknowledgeCheckpoint( ExecutionAttemptID taskExecutionId, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState) { + TaskStateSnapshot subtaskState) { super(job, taskExecutionId, checkpointId); @@ -64,7 +64,7 @@ public AcknowledgeCheckpoint(JobID jobId, ExecutionAttemptID taskExecutionId, lo // properties // ------------------------------------------------------------------------ - public SubtaskState getSubtaskState() { + public TaskStateSnapshot getSubtaskState() { return subtaskState; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java index bf6016126af1d..aba8bda191825 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorGateway; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.util.Preconditions; @@ -40,7 +40,7 @@ public void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState) { + TaskStateSnapshot subtaskState) { checkpointCoordinatorGateway.acknowledgeCheckpoint( jobID, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java index ad0df7151c20a..e9f600d672abc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java @@ -20,7 +20,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; @@ -44,7 +44,7 @@ public void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { AcknowledgeCheckpoint message = new AcknowledgeCheckpoint( jobID, executionAttemptID, checkpointId, checkpointMetrics, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java index cc66a3f283160..b3584a6dfc987 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java @@ -20,7 +20,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; /** @@ -47,7 +47,7 @@ void acknowledgeCheckpoint( ExecutionAttemptID executionAttemptID, long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState subtaskState); + TaskStateSnapshot subtaskState); /** * Declines the given checkpoint. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 788a59090d309..92b58868d666f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -26,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -245,7 +245,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { checkpointResponder.acknowledgeCheckpoint( jobId, executionId, checkpointId, checkpointMetrics, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 344b34093d9b7..584d4fa31cbdb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -23,14 +23,15 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; + import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -42,8 +43,9 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -89,29 +91,26 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertFalse(pendingCheckpoint.isDiscarded()); final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); - - SubtaskState subtaskState = mock(SubtaskState.class); + StreamStateHandle legacyHandle = mock(StreamStateHandle.class); - ChainedStateHandle chainedLegacyHandle = mock(ChainedStateHandle.class); - when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle); - when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle); + KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); + KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class); + OperatorStateHandle managedOpHandle = mock(OperatorStateHandle.class); + OperatorStateHandle rawOpHandle = mock(OperatorStateHandle.class); - OperatorStateHandle managedHandle = mock(OperatorStateHandle.class); - ChainedStateHandle chainedManagedHandle = mock(ChainedStateHandle.class); - when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle); - when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle); + final OperatorSubtaskState operatorSubtaskState = spy(new OperatorSubtaskState( + legacyHandle, + managedOpHandle, + rawOpHandle, + managedKeyedHandle, + rawKeyedHandle)); - OperatorStateHandle rawHandle = mock(OperatorStateHandle.class); - ChainedStateHandle chainedRawHandle = mock(ChainedStateHandle.class); - when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle); - when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle); + TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot()); + subtaskState.putSubtaskStateByOperatorID(new OperatorID(), operatorSubtaskState); + + when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState); - KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); - when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle); - KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class); - when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle); - AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState); try { @@ -126,11 +125,12 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the subtask state has been discarded after we could not complete it. - verify(subtaskState.getLegacyOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedOperatorState().get(0)).discardState(); - verify(subtaskState.getRawOperatorState().get(0)).discardState(); - verify(subtaskState.getManagedKeyedState()).discardState(); - verify(subtaskState.getRawKeyedState()).discardState(); + verify(operatorSubtaskState).discardState(); + verify(operatorSubtaskState.getLegacyOperatorState()).discardState(); + verify(operatorSubtaskState.getManagedOperatorState()).discardState(); + verify(operatorSubtaskState.getRawOperatorState()).discardState(); + verify(operatorSubtaskState.getManagedKeyedState()).discardState(); + verify(operatorSubtaskState.getRawKeyedState()).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 186a8196aaaa5..7b87d1ea81a44 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.checkpoint; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.java.tuple.Tuple2; @@ -56,6 +54,9 @@ import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -91,7 +92,6 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -100,7 +100,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; /** * Tests for the checkpoint coordinator. @@ -553,31 +552,29 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map operatorStates = checkpoint.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); - // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); } + OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); + // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); - verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); + verify(taskOperatorSubtaskStates2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the same task again (should not matter) coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); @@ -586,8 +583,7 @@ public void testTriggerAndConfirmSimpleCheckpoint() { verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -626,9 +622,7 @@ public void testTriggerAndConfirmSimpleCheckpoint() { long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -850,18 +844,20 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId()); - Map operatorStates1 = pending1.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot()); - operatorStates1.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates1.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates1.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1); + taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2); + taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3); // acknowledge one of the three tasks - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2)); + // start the second checkpoint // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); @@ -878,14 +874,17 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { } long checkpointId2 = pending2.getCheckpointId(); - Map operatorStates2 = pending2.getOperatorStates(); + TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot()); + TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot()); - operatorStates2.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates2.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); - operatorStates2.put(opID3, new SpyInjectingOperatorState( - opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class); + + taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1); + taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2); + taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3); // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); @@ -894,17 +893,13 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. The second checkpoint should then subsume the first checkpoint - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2)); // now, the second checkpoint should be confirmed, and the first discarded // actually both pending checkpoints are discarded, and the second has been transformed @@ -936,8 +931,7 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); // send the last remaining ack for the first checkpoint. This should not do anything - SubtaskState subtaskState1_3 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3)); verify(subtaskState1_3, times(1)).discardState(); coord.shutdown(JobStatus.FINISHED); @@ -1003,13 +997,11 @@ public void testCheckpointTimeoutIsolated() { OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - Map operatorStates = checkpoint.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1)); // wait until the checkpoint must have expired. // we check every 250 msecs conservatively for 5 seconds @@ -1027,7 +1019,7 @@ public void testCheckpointTimeoutIsolated() { assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); // validate that the received states have been discarded - verify(subtaskState, times(1)).discardState(); + verify(subtaskState1, times(1)).discardState(); // no confirm message must have been sent verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong()); @@ -1145,26 +1137,18 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { long checkpointId = pendingCheckpoint.getCheckpointId(); OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId()); - OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); - OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); - Map operatorStates = pendingCheckpoint.getOperatorStates(); - - operatorStates.put(opIDtrigger, new SpyInjectingOperatorState( - opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism())); - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class); + taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, subtaskStateTrigger); // acknowledge the first trigger vertex - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStatesTrigger)); // verify that the subtask state has not been discarded - verify(storedTriggerSubtaskState, never()).discardState(); + verify(subtaskStateTrigger, never()).discardState(); - SubtaskState unknownSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState)); @@ -1172,7 +1156,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { // we should discard acknowledge messages from an unknown vertex belonging to our job verify(unknownSubtaskState, times(1)).discardState(); - SubtaskState differentJobSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot differentJobSubtaskState = mock(TaskStateSnapshot.class); // receive an acknowledge message from an unknown job coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState)); @@ -1181,22 +1165,22 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the trigger vertex - SubtaskState triggerSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex - reset(storedTriggerSubtaskState); + reset(subtaskStateTrigger); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state - verify(storedTriggerSubtaskState, times(1)).discardState(); + verify(subtaskStateTrigger, times(1)).discardState(); - SubtaskState ackSubtaskState = mock(SubtaskState.class); + TaskStateSnapshot ackSubtaskState = mock(TaskStateSnapshot.class); // late acknowledge message from the second ack vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState)); @@ -1211,7 +1195,7 @@ public void testStateCleanupForLateOrUnknownMessages() throws Exception { // we should not interfere with different jobs verify(differentJobSubtaskState, never()).discardState(); - SubtaskState unknownSubtaskState2 = mock(SubtaskState.class); + TaskStateSnapshot unknownSubtaskState2 = mock(TaskStateSnapshot.class); // receive an acknowledge message for an unknown vertex coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2)); @@ -1468,18 +1452,16 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); - - Map operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); + TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class); + TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class); + OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class); + when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1); + when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2); // acknowledge from one of the tasks - AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); + AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); - OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); @@ -1493,8 +1475,7 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { assertFalse(savepointFuture.isDone()); // acknowledge the other task. - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); - OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1)); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1534,9 +1515,6 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); - subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); - subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); - assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -2035,20 +2013,8 @@ public void testRestoreLatestCheckpointedState() throws Exception { List keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); - PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); - - OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1); - OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2); - - Map operatorStates = pending.getOperatorStates(); - - operatorStates.put(opID1, new SpyInjectingOperatorState( - opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism())); - operatorStates.put(opID2, new SpyInjectingOperatorState( - opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism())); - for (int index = 0; index < jobVertex1.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2061,7 +2027,7 @@ public void testRestoreLatestCheckpointedState() throws Exception { } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); + TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index)); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, @@ -2163,30 +2129,34 @@ public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2282,17 +2252,20 @@ public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Ex StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID1, keyGroupPartitions1.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2300,17 +2273,19 @@ public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Ex for (int index = 0; index < jobVertex2.getParallelism(); index++) { - ChainedStateHandle state = generateStateForVertex(jobVertexID2, index); + StreamStateHandle state = generateStateForVertex(jobVertexID2, index); KeyGroupsStateHandle keyGroupState = generateKeyGroupState( jobVertexID2, keyGroupPartitions2.get(index), false); - SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(state, null, null, keyGroupState, null); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2436,18 +2411,21 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s //vertex 1 for (int index = 0; index < jobVertex1.getParallelism(); index++) { - ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); + StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false); KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true); - SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2458,19 +2436,21 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s for (int index = 0; index < jobVertex2.getParallelism(); index++) { KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true); - ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); - ChainedStateHandle opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true); - expectedOpStatesBackend.add(opStateBackend); - expectedOpStatesRaw.add(opStateRaw); - SubtaskState checkpointStateHandles = - new SubtaskState(new ChainedStateHandle<>( - Collections.singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false); + OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true); + expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend))); + expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw))); + + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw); + TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), checkpointId, new CheckpointMetrics(), - checkpointStateHandles); + taskOperatorSubtaskStates); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } @@ -2525,6 +2505,7 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); } + comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw); } @@ -2576,14 +2557,11 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception operatorStates.put(id.f1, taskState); for (int index = 0; index < taskState.getParallelism(); index++) { StreamStateHandle subNonPartitionedState = - generateStateForVertex(id.f0, index) - .get(0); + generateStateForVertex(id.f0, index); OperatorStateHandle subManagedOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false) - .get(0); + generatePartitionableStateHandle(id.f0, index, 2, 8, false); OperatorStateHandle subRawOperatorState = - generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true) - .get(0); + generatePartitionableStateHandle(id.f0, index, 2, 8, true); OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState, subManagedOperatorState, @@ -2723,38 +2701,38 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception // operator1 { int operatorIndexInChain = 1; - ChainedStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); - ChainedStateHandle expectedManagedOpState = generateChainedPartitionableStateHandle( + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, false); - ChainedStateHandle expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, true); assertTrue(CommonTestUtils.isSteamContentEqual( - expectSubNonPartitionedState.get(0).openInputStream(), + expectSubNonPartitionedState.openInputStream(), actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); } // operator2 { int operatorIndexInChain = 0; - ChainedStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); - ChainedStateHandle expectedManagedOpState = generateChainedPartitionableStateHandle( + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); + OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, false); - ChainedStateHandle expectedRawOpState = generateChainedPartitionableStateHandle( + OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, true); - assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(), + assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.openInputStream(), actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(), + assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); - assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(), + assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); } } @@ -2972,19 +2950,50 @@ public static Tuple2> serializeTogetherAndTrackOffsets( return new Tuple2<>(allSerializedValuesConcatenated, offsets); } - public static ChainedStateHandle generateStateForVertex( + public static StreamStateHandle generateStateForVertex( JobVertexID jobVertexID, int index) throws IOException { Random random = new Random(jobVertexID.hashCode() + index); int value = random.nextInt(); - return generateChainedStateHandle(value); + return generateStreamStateHandle(value); + } + + public static StreamStateHandle generateStreamStateHandle(Serializable value) throws IOException { + return TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value); } public static ChainedStateHandle generateChainedStateHandle( Serializable value) throws IOException { return ChainedStateHandle.wrapSingleHandle( - TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value)); + generateStreamStateHandle(value)); + } + + public static OperatorStateHandle generatePartitionableStateHandle( + JobVertexID jobVertexID, + int index, + int namedStates, + int partitionsPerState, + boolean rawState) throws IOException { + + Map> statesListsMap = new HashMap<>(namedStates); + + for (int i = 0; i < namedStates; ++i) { + List testStatesLists = new ArrayList<>(partitionsPerState); + // generate state + int seed = jobVertexID.hashCode() * index + i * namedStates; + if (rawState) { + seed = (seed + 1) * 31; + } + Random random = new Random(seed); + for (int j = 0; j < partitionsPerState; ++j) { + int simulatedStateValue = random.nextInt(); + testStatesLists.add(simulatedStateValue); + } + statesListsMap.put("state-" + i, testStatesLists); + } + + return generatePartitionableStateHandle(statesListsMap); } public static ChainedStateHandle generateChainedPartitionableStateHandle( @@ -3011,11 +3020,11 @@ public static ChainedStateHandle generateChainedPartitionab statesListsMap.put("state-" + i, testStatesLists); } - return generateChainedPartitionableStateHandle(statesListsMap); + return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap)); } - private static ChainedStateHandle generateChainedPartitionableStateHandle( - Map> states) throws IOException { + private static OperatorStateHandle generatePartitionableStateHandle( + Map> states) throws IOException { List> namedStateSerializables = new ArrayList<>(states.size()); @@ -3030,20 +3039,18 @@ private static ChainedStateHandle generateChainedPartitiona int idx = 0; for (Map.Entry> entry : states.entrySet()) { offsetsMap.put( - entry.getKey(), - new OperatorStateHandle.StateMetaInfo( - serializationWithOffsets.f1.get(idx), - OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + entry.getKey(), + new OperatorStateHandle.StateMetaInfo( + serializationWithOffsets.f1.get(idx), + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); ++idx; } ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare( - String.valueOf(UUID.randomUUID()), - serializationWithOffsets.f0); + String.valueOf(UUID.randomUUID()), + serializationWithOffsets.f0); - OperatorStateHandle operatorStateHandle = - new OperatorStateHandle(offsetsMap, streamStateHandle); - return ChainedStateHandle.wrapSingleHandle(operatorStateHandle); + return new OperatorStateHandle(offsetsMap, streamStateHandle); } static ExecutionJobVertex mockExecutionJobVertex( @@ -3137,24 +3144,23 @@ private static ExecutionVertex mockExecutionVertex( return vertex; } - static SubtaskState mockSubtaskState( + static TaskStateSnapshot mockSubtaskState( JobVertexID jobVertexID, int index, KeyGroupRange keyGroupRange) throws IOException { - ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false); + StreamStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index); + OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false); KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false); - SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable()); + TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot()); + OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState( + nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null) + ); - doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState(); - doReturn(partitionableState).when(subtaskState).getManagedOperatorState(); - doReturn(null).when(subtaskState).getRawOperatorState(); - doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState(); - doReturn(null).when(subtaskState).getRawKeyedState(); + subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState); - return subtaskState; + return subtaskStates; } public static void verifyStateRestore( @@ -3165,10 +3171,10 @@ public static void verifyStateRestore( TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); - ChainedStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); + StreamStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); ChainedStateHandle actualNonPartitionedState = taskStateHandles.getLegacyOperatorState(); assertTrue(CommonTestUtils.isSteamContentEqual( - expectNonPartitionedState.get(0).openInputStream(), + expectNonPartitionedState.openInputStream(), actualNonPartitionedState.get(0).openInputStream())); ChainedStateHandle expectedOpStateBackend = @@ -3631,16 +3637,16 @@ public void testSavepointsAreNotAddedToCompletedCheckpointStore() throws Excepti completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast()); } - private static final class SpyInjectingOperatorState extends OperatorState { - - private static final long serialVersionUID = -4004437428483663815L; - - public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { - super(taskID, parallelism, maxParallelism); - } - - public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { - super.putState(subtaskIndex, spy(subtaskState)); - } - } +// private static final class SpyInjectingOperatorState extends OperatorState { +// +// private static final long serialVersionUID = -4004437428483663815L; +// +// public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { +// super(taskID, parallelism, maxParallelism); +// } +// +// public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { +// super.putState(subtaskIndex, (subtaskState != null) ? spy(subtaskState) : null); +// } +// } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 7d2456881fc33..f4807a3ef8b54 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -36,6 +36,7 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.SerializableObject; + import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.junit.Test; @@ -118,10 +119,22 @@ public void testSetState() { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles)); + TaskStateSnapshot subtaskStates = new TaskStateSnapshot(); + + subtaskStates.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(statefulId), + new OperatorSubtaskState( + serializedState.get(0), + null, + null, + serializedKeyGroupStates, + null)); + + //SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index a96b5979a4159..2fc0d0c7c3416 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -324,7 +324,7 @@ public void testNullSubtaskStateLeadsToStatelessTask() throws Exception { @Test public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception { PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null); - pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class)); + pending.acknowledgeTask(ATTEMPT_ID, mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class)); Assert.assertFalse(pending.getOperatorStates().isEmpty()); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index a63b02d785f19..38964168e1576 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -18,16 +18,6 @@ package org.apache.flink.runtime.jobmanager; -import akka.actor.ActorRef; -import akka.actor.ActorSystem; -import akka.actor.Identify; -import akka.actor.PoisonPill; -import akka.actor.Props; -import akka.japi.pf.FI; -import akka.japi.pf.ReceiveBuilder; -import akka.pattern.Patterns; -import akka.testkit.CallingThreadDispatcher; -import akka.testkit.JavaTestKit; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; @@ -44,8 +34,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy; @@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; @@ -69,8 +61,6 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; @@ -83,23 +73,24 @@ import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; import org.apache.flink.util.InstantiationUtil; - import org.apache.flink.util.TestLogger; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Identify; +import akka.actor.PoisonPill; +import akka.actor.Props; +import akka.japi.pf.FI; +import akka.japi.pf.ReceiveBuilder; +import akka.pattern.Patterns; +import akka.testkit.CallingThreadDispatcher; +import akka.testkit.JavaTestKit; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import scala.Int; -import scala.Option; -import scala.PartialFunction; -import scala.concurrent.Await; -import scala.concurrent.Future; -import scala.concurrent.duration.Deadline; -import scala.concurrent.duration.FiniteDuration; -import scala.runtime.BoxedUnit; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -113,6 +104,15 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import scala.Int; +import scala.Option; +import scala.PartialFunction; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.runtime.BoxedUnit; + import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -567,10 +567,16 @@ public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, Checkpoi String.valueOf(UUID.randomUUID()), InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId())); - ChainedStateHandle chainedStateHandle = - new ChainedStateHandle(Collections.singletonList(byteStreamStateHandle)); - SubtaskState checkpointStateHandles = - new SubtaskState(chainedStateHandle, null, null, null, null); + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()), + new OperatorSubtaskState( + byteStreamStateHandle, + null, + null, + null, + null) + ); getEnvironment().acknowledgeCheckpoint( checkpointMetaData.getCheckpointId(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java index bc420cc27799b..d022cdcf59e69 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java @@ -24,14 +24,17 @@ import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.StreamStateHandle; + import org.junit.Test; import java.io.IOException; @@ -68,13 +71,17 @@ public void testConfirmTaskCheckpointed() { KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42); - SubtaskState checkpointStateHandles = - new SubtaskState( - CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), - CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), - null, - CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), - null); + TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); + checkpointStateHandles.putSubtaskStateByOperatorID( + new OperatorID(), + new OperatorSubtaskState( + CheckpointCoordinatorTest.generateStreamStateHandle(new MyHandle()), + CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), + null, + CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), + null + ) + ); AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint( new JobID(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 851fa967be729..8ed06b2ef3682 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -26,7 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -156,7 +156,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 4f0242e131ab5..7514cc4200d74 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -27,7 +27,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -50,8 +50,8 @@ import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.types.Record; import org.apache.flink.util.MutableObjectIterator; - import org.apache.flink.util.Preconditions; + import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -354,7 +354,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { throw new UnsupportedOperationException(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java index c78a3d5bef2cb..984c35856be7d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java @@ -27,7 +27,7 @@ import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -70,7 +70,8 @@ import java.util.concurrent.Executors; import static org.junit.Assume.assumeTrue; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Test that verifies the behavior of blocking shutdown hooks and of the @@ -235,7 +236,7 @@ public InputSplit getNextInputSplit(ClassLoader userCodeClassLoader) { private static final class NoOpCheckpointResponder implements CheckpointResponder { @Override - public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, SubtaskState s) {} + public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, TaskStateSnapshot s) {} @Override public void declineCheckpoint(JobID j, ExecutionAttemptID e, long l, Throwable t) {} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java index 77caa34d11fbf..13100db01d837 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.util.CorruptConfigurationException; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.util.ClassLoaderUtil; @@ -76,6 +77,7 @@ public class StreamConfig implements Serializable { private static final String OUT_STREAM_EDGES = "outStreamEdges"; private static final String IN_STREAM_EDGES = "inStreamEdges"; private static final String OPERATOR_NAME = "operatorName"; + private static final String OPERATOR_ID = "operatorID"; private static final String CHAIN_END = "chainEnd"; private static final String CHECKPOINTING_ENABLED = "checkpointing"; @@ -213,7 +215,7 @@ public void setStreamOperator(StreamOperator operator) { } } - public T getStreamOperator(ClassLoader cl) { + public > T getStreamOperator(ClassLoader cl) { try { return InstantiationUtil.readObjectFromConfig(this.config, SERIALIZEDUDF, cl); } @@ -411,6 +413,15 @@ public Map getTransitiveChainedTaskConfigs(ClassLoader cl } } + public void setOperatorID(OperatorID operatorID) { + this.config.setBytes(OPERATOR_ID, operatorID.getBytes()); + } + + public OperatorID getOperatorID() { + byte[] operatorIDBytes = config.getBytes(OPERATOR_ID, null); + return new OperatorID(Preconditions.checkNotNull(operatorIDBytes)); + } + public void setOperatorName(String name) { this.config.setString(OPERATOR_NAME, name); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index e70962b9b2623..abaa74e7d2ecc 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -246,7 +246,9 @@ private List createChain( operatorHashes = new ArrayList<>(); chainedOperatorHashes.put(startNodeId, operatorHashes); } - operatorHashes.add(new Tuple2<>(hashes.get(currentNodeId), legacyHashes.get(1).get(currentNodeId))); + + byte[] primaryHashBytes = hashes.get(currentNodeId); + operatorHashes.add(new Tuple2<>(primaryHashBytes, legacyHashes.get(1).get(currentNodeId))); chainedNames.put(currentNodeId, createChainedName(currentNodeId, chainableOutputs)); chainedMinResources.put(currentNodeId, createChainedMinResources(currentNodeId, chainableOutputs)); @@ -280,13 +282,16 @@ private List createChain( chainedConfigs.put(startNodeId, new HashMap()); } config.setChainIndex(chainIndex); - config.setOperatorName(streamGraph.getStreamNode(currentNodeId).getOperatorName()); + StreamNode node = streamGraph.getStreamNode(currentNodeId); + config.setOperatorName(node.getOperatorName()); chainedConfigs.get(startNodeId).put(currentNodeId, config); } + + config.setOperatorID(new OperatorID(primaryHashBytes)); + if (chainableOutputs.isEmpty()) { config.setChainEnd(); } - return transitiveOutEdges; } else { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index d711518e49f4c..324bc8c6bcbcd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -36,6 +36,7 @@ import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; @@ -179,7 +180,6 @@ public abstract class AbstractStreamOperator public void setup(StreamTask containingTask, StreamConfig config, Output> output) { this.container = containingTask; this.config = config; - this.metrics = container.getEnvironment().getMetricGroup().addOperator(config.getOperatorName()); this.output = new CountingOutput(output, ((OperatorMetricGroup) this.metrics).getIOMetricGroup().getNumRecordsOutCounter()); if (config.isChainStart()) { @@ -973,6 +973,11 @@ public void processWatermark2(Watermark mark) throws Exception { } } + @Override + public OperatorID getOperatorID() { + return config.getOperatorID(); + } + @VisibleForTesting public int numProcessingTimeTimers() { return timeServiceManager == null ? 0 : diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 61578b23a6d51..3c26f50ebaf0e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -149,4 +150,5 @@ StreamStateHandle snapshotLegacyOperatorState( MetricGroup getMetricGroup(); + OperatorID getOperatorID(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index c35a6dc5b3684..4a6a4fb7144a2 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -25,12 +25,14 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; @@ -64,13 +66,12 @@ import java.io.Closeable; import java.io.IOException; -import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.RunnableFuture; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicReference; @@ -675,6 +676,7 @@ private void initializeOperators(boolean restored) throws Exception { StreamOperator operator = allOperators[chainIdx]; if (null != operator) { if (restored && restoreStateHandles != null) { + operator.initializeState(restoreStateHandles.getStateByOperatorID(operator.getOperatorID())); operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx)); } else { operator.initializeState(null); @@ -850,12 +852,9 @@ private static final class AsyncCheckpointRunnable implements Runnable, Closeabl private final StreamTask owner; - private final List snapshotInProgressList; + private final Map operatorSnapshotsInProgress; - private RunnableFuture futureKeyedBackendStateHandles; - private RunnableFuture futureKeyedStreamStateHandles; - - private List nonPartitionedStateHandles; + private Map nonPartitionedStateHandles; private final CheckpointMetaData checkpointMetaData; private final CheckpointMetrics checkpointMetrics; @@ -867,81 +866,60 @@ private static final class AsyncCheckpointRunnable implements Runnable, Closeabl AsyncCheckpointRunnable( StreamTask owner, - List nonPartitionedStateHandles, - List snapshotInProgressList, + Map nonPartitionedStateHandles, + Map operatorSnapshotsInProgress, CheckpointMetaData checkpointMetaData, CheckpointMetrics checkpointMetrics, long asyncStartNanos) { this.owner = Preconditions.checkNotNull(owner); - this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList); + this.operatorSnapshotsInProgress = Preconditions.checkNotNull(operatorSnapshotsInProgress); this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData); this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics); this.nonPartitionedStateHandles = nonPartitionedStateHandles; this.asyncStartNanos = asyncStartNanos; - - if (!snapshotInProgressList.isEmpty()) { - // TODO Currently only the head operator of a chain can have keyed state, so simply access it directly. - int headIndex = snapshotInProgressList.size() - 1; - OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex); - if (null != snapshotInProgress) { - this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture(); - this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture(); - } - } } @Override public void run() { FileSystemSafetyNet.initializeSafetyNetForThread(); try { - // Keyed state handle future, currently only one (the head) operator can have this - KeyedStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles); - KeyedStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles); - - List operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size()); - List operatorStatesStream = new ArrayList<>(snapshotInProgressList.size()); - - for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) { - if (null != snapshotInProgress) { - operatorStatesBackend.add( - FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture())); - operatorStatesStream.add( - FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture())); - } else { - operatorStatesBackend.add(null); - operatorStatesStream.add(null); - } - } + boolean hasState = false; + final TaskStateSnapshot taskOperatorSubtaskStates = + new TaskStateSnapshot(operatorSnapshotsInProgress.size()); - final long asyncEndNanos = System.nanoTime(); - final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; + for (Map.Entry entry : operatorSnapshotsInProgress.entrySet()) { - checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis); + OperatorID operatorID = entry.getKey(); + OperatorSnapshotResult snapshotInProgress = entry.getValue(); - ChainedStateHandle chainedNonPartitionedOperatorsState = - new ChainedStateHandle<>(nonPartitionedStateHandles); + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( + nonPartitionedStateHandles.get(operatorID), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture()), + FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture()) + ); - ChainedStateHandle chainedOperatorStateBackend = - new ChainedStateHandle<>(operatorStatesBackend); + hasState |= operatorSubtaskState.hasState(); + taskOperatorSubtaskStates.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + } - ChainedStateHandle chainedOperatorStateStream = - new ChainedStateHandle<>(operatorStatesStream); + final long asyncEndNanos = System.nanoTime(); + final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; - SubtaskState subtaskState = createSubtaskStateFromSnapshotStateHandles( - chainedNonPartitionedOperatorsState, - chainedOperatorStateBackend, - chainedOperatorStateStream, - keyedStateHandleBackend, - keyedStateHandleStream); + checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis); if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.COMPLETED)) { + // we signal a stateless task by reporting null, so that there are no attempts to assign empty state + // to stateless tasks on restore. This enables simple job modifications that only concern + // stateless without the need to assign them uids to match their (always empty) states. owner.getEnvironment().acknowledgeCheckpoint( checkpointMetaData.getCheckpointId(), checkpointMetrics, - subtaskState); + hasState ? taskOperatorSubtaskStates : null); if (LOG.isDebugEnabled()) { LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", @@ -988,38 +966,13 @@ public void close() { } } - private SubtaskState createSubtaskStateFromSnapshotStateHandles( - ChainedStateHandle chainedNonPartitionedOperatorsState, - ChainedStateHandle chainedOperatorStateBackend, - ChainedStateHandle chainedOperatorStateStream, - KeyedStateHandle keyedStateHandleBackend, - KeyedStateHandle keyedStateHandleStream) { - - boolean hasAnyState = keyedStateHandleBackend != null - || keyedStateHandleStream != null - || !chainedOperatorStateBackend.isEmpty() - || !chainedOperatorStateStream.isEmpty() - || !chainedNonPartitionedOperatorsState.isEmpty(); - - // we signal a stateless task by reporting null, so that there are no attempts to assign empty state to - // stateless tasks on restore. This allows for simple job modifications that only concern stateless without - // the need to assign them uids to match their (always empty) states. - return hasAnyState ? new SubtaskState( - chainedNonPartitionedOperatorsState, - chainedOperatorStateBackend, - chainedOperatorStateStream, - keyedStateHandleBackend, - keyedStateHandleStream) - : null; - } - private void cleanup() throws Exception { if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.DISCARDED)) { LOG.debug("Cleanup AsyncCheckpointRunnable for checkpoint {} of {}.", checkpointMetaData.getCheckpointId(), owner.getName()); Exception exception = null; // clean up ongoing operator snapshot results and non partitioned state handles - for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) { + for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) { if (operatorSnapshotResult != null) { try { operatorSnapshotResult.cancel(); @@ -1031,7 +984,7 @@ private void cleanup() throws Exception { // discard non partitioned state handles try { - StateUtil.bestEffortDiscardAllStateObjects(nonPartitionedStateHandles); + StateUtil.bestEffortDiscardAllStateObjects(nonPartitionedStateHandles.values()); } catch (Exception discardException) { exception = ExceptionUtils.firstOrSuppressed(discardException, exception); } @@ -1069,8 +1022,8 @@ private static final class CheckpointingOperation { // ------------------------ - private final List nonPartitionedStates; - private final List snapshotInProgressList; + private final Map nonPartitionedStates; + private final Map operatorSnapshotsInProgress; public CheckpointingOperation( StreamTask owner, @@ -1083,8 +1036,8 @@ public CheckpointingOperation( this.checkpointOptions = Preconditions.checkNotNull(checkpointOptions); this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics); this.allOperators = owner.operatorChain.getAllOperators(); - this.nonPartitionedStates = new ArrayList<>(allOperators.length); - this.snapshotInProgressList = new ArrayList<>(allOperators.length); + this.nonPartitionedStates = new HashMap<>(allOperators.length); + this.operatorSnapshotsInProgress = new HashMap<>(allOperators.length); } public void executeCheckpointing() throws Exception { @@ -1119,7 +1072,7 @@ public void executeCheckpointing() throws Exception { } finally { if (failed) { // Cleanup to release resources - for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) { + for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) { if (null != operatorSnapshotResult) { try { operatorSnapshotResult.cancel(); @@ -1130,7 +1083,7 @@ public void executeCheckpointing() throws Exception { } // Cleanup non partitioned state handles - for (StreamStateHandle nonPartitionedState : nonPartitionedStates) { + for (StreamStateHandle nonPartitionedState : nonPartitionedStates.values()) { if (nonPartitionedState != null) { try { nonPartitionedState.discardState(); @@ -1156,21 +1109,19 @@ public void executeCheckpointing() throws Exception { private void checkpointStreamOperator(StreamOperator op) throws Exception { if (null != op) { // first call the legacy checkpoint code paths - nonPartitionedStates.add(op.snapshotLegacyOperatorState( - checkpointMetaData.getCheckpointId(), - checkpointMetaData.getTimestamp(), - checkpointOptions)); + StreamStateHandle legacyOperatorState = op.snapshotLegacyOperatorState( + checkpointMetaData.getCheckpointId(), + checkpointMetaData.getTimestamp(), + checkpointOptions); + + OperatorID operatorID = op.getOperatorID(); + nonPartitionedStates.put(operatorID, legacyOperatorState); OperatorSnapshotResult snapshotInProgress = op.snapshotState( checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp(), checkpointOptions); - - snapshotInProgressList.add(snapshotInProgress); - } else { - nonPartitionedStates.add(null); - OperatorSnapshotResult emptySnapshotInProgress = new OperatorSnapshotResult(); - snapshotInProgressList.add(emptySnapshotInProgress); + operatorSnapshotsInProgress.put(operatorID, snapshotInProgress); } } @@ -1179,7 +1130,7 @@ public void runAsyncCheckpointingAndAcknowledge() throws IOException { AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable( owner, nonPartitionedStates, - snapshotInProgressList, + operatorSnapshotsInProgress, checkpointMetaData, checkpointMetrics, startAsyncPartNano); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java index e8b4c9e83c98e..6d9eb0e9062ef 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.runtime.state.StreamStateHandle; @@ -84,7 +85,7 @@ public class AbstractUdfStreamOperatorLifecycleTest { "UDF::close"); private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " + - "getMetricGroup[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " + + "getMetricGroup[], getOperatorID[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " + "notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " + "org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " + "org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " + @@ -132,6 +133,7 @@ public void testLifeCycleFull() throws Exception { MockSourceFunction srcFun = new MockSourceFunction(); cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, true)); + cfg.setOperatorID(new OperatorID()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig); @@ -154,6 +156,7 @@ public void testLifeCycleCancel() throws Exception { StreamConfig cfg = new StreamConfig(new Configuration()); MockSourceFunction srcFun = new MockSourceFunction(); cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, false)); + cfg.setOperatorID(new OperatorID()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java index f9a1cd00ed091..abfb5bccce200 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java @@ -29,11 +29,12 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -62,6 +63,7 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.OperatorIDMappedStateToChainConverter; import org.apache.flink.streaming.util.TestHarnessUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; @@ -501,6 +503,7 @@ public void testStateSnapshotAndRestore() throws Exception { final StreamConfig streamConfig = testHarness.getStreamConfig(); streamConfig.setStreamOperator(operator); + streamConfig.setOperatorID(new OperatorID(42L, 4711L)); final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( testHarness.jobConfig, @@ -540,7 +543,9 @@ public void testStateSnapshotAndRestore() throws Exception { // set the operator state from previous attempt into the restored one final OneInputStreamTask restoredTask = new OneInputStreamTask<>(); - restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles())); + TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles(); + TaskStateHandles stateHandles = OperatorIDMappedStateToChainConverter.convert(subtaskStates, streamConfig, 1); + restoredTask.setInitialState(stateHandles); final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); @@ -595,7 +600,7 @@ public void testStateSnapshotAndRestore() throws Exception { private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment { private volatile long checkpointId; - private volatile SubtaskState checkpointStateHandles; + private volatile TaskStateSnapshot checkpointStateHandles; private final OneShotLatch checkpointLatch = new OneShotLatch(); @@ -614,7 +619,7 @@ public long getCheckpointId() { public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { this.checkpointId = checkpointId; this.checkpointStateHandles = checkpointStateHandles; @@ -625,7 +630,7 @@ public OneShotLatch getCheckpointLatch() { return checkpointLatch; } - public SubtaskState getCheckpointStateHandles() { + public TaskStateSnapshot getCheckpointStateHandles() { return checkpointStateHandles; } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java index 6e3be0365fc33..65e59f8ac756c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; @@ -53,6 +54,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamMap mapOperator = new StreamMap<>(new DummyMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); testHarness.waitForTaskRunning(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java index 675ffa3570ba7..d621b0bb12adb 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.operators; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.runtime.tasks.AsyncExceptionHandler; @@ -53,6 +54,7 @@ public void testCustomTimeServiceProvider() throws Throwable { StreamMap mapOperator = new StreamMap<>(new StreamTaskTimerTest.DummyMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java index 51328abbebc12..3b8178bb16fbf 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java @@ -45,6 +45,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -93,6 +94,7 @@ public void testBlockingNonInterruptibleCheckpoint() throws Exception { Configuration taskConfig = new Configuration(); StreamConfig cfg = new StreamConfig(taskConfig); cfg.setStreamOperator(new TestOperator()); + cfg.setOperatorID(new OperatorID()); cfg.setStateBackend(new LockingStreamStateBackend()); Task task = createTask(taskConfig); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 25b504b916c16..691f0de4a8b16 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -144,9 +145,11 @@ private void testRestoreWithInterrupt(int mode) throws Exception { case KEYED_RAW: cfg.setStateKeySerializer(IntSerializer.INSTANCE); cfg.setStreamOperator(new StreamSource<>(new TestSource())); + cfg.setOperatorID(new OperatorID()); break; case LEGACY: cfg.setStreamOperator(new StreamSource<>(new TestSourceLegacy())); + cfg.setOperatorID(new OperatorID()); break; default: throw new IllegalArgumentException(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index f7987a122589f..84a7ef7af126e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -34,9 +34,10 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -53,6 +54,7 @@ import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatus; +import org.apache.flink.streaming.util.OperatorIDMappedStateToChainConverter; import org.apache.flink.streaming.util.TestHarnessUtil; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; @@ -109,6 +111,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new TestOpenCloseMapFunction()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); long initialTime = 0L; ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); @@ -151,6 +154,7 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -412,6 +416,7 @@ public void testCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -471,6 +476,7 @@ public void testOvertakingCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -580,15 +586,23 @@ public void testSnapshottingAndRestoring() throws Exception { testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); final OneInputStreamTask restoredTask = new OneInputStreamTask(); - restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles())); - final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); + final OneInputStreamTaskTestHarness restoredTaskHarness = + new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig(); configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks, seed, recoveryTimestamp); + TaskStateSnapshot stateHandles = env.getCheckpointStateHandles(); + Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size()); + + TaskStateHandles taskStateHandles = + OperatorIDMappedStateToChainConverter.convert(stateHandles, restoredTaskStreamConfig, numberChainedTasks); + + restoredTask.setInitialState(taskStateHandles); + TestingStreamOperator.numberRestoreCalls = 0; restoredTaskHarness.invoke(); @@ -601,6 +615,7 @@ public void testSnapshottingAndRestoring() throws Exception { TestingStreamOperator.numberRestoreCalls = 0; } + //============================================================================================== // Utility functions and classes //============================================================================================== @@ -618,6 +633,7 @@ private void configureChainedTestingStreamOperator( TestingStreamOperator previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); streamConfig.setStreamOperator(previousOperator); + streamConfig.setOperatorID(new OperatorID(0L, 0L)); // create the chain of operators Map chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1); @@ -627,6 +643,7 @@ private void configureChainedTestingStreamOperator( TestingStreamOperator chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); StreamConfig chainedConfig = new StreamConfig(new Configuration()); chainedConfig.setStreamOperator(chainedOperator); + chainedConfig.setOperatorID(new OperatorID(0L, chainedIndex)); chainedTaskConfigs.put(chainedIndex, chainedConfig); StreamEdge outputEdge = new StreamEdge( @@ -673,7 +690,7 @@ public IN getKey(IN value) throws Exception { private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment { private volatile long checkpointId; - private volatile SubtaskState checkpointStateHandles; + private volatile TaskStateSnapshot checkpointStateHandles; private final OneShotLatch checkpointLatch = new OneShotLatch(); @@ -692,7 +709,7 @@ public long getCheckpointId() { public void acknowledgeCheckpoint( long checkpointId, CheckpointMetrics checkpointMetrics, - SubtaskState checkpointStateHandles) { + TaskStateSnapshot checkpointStateHandles) { this.checkpointId = checkpointId; this.checkpointStateHandles = checkpointStateHandles; @@ -703,7 +720,7 @@ public OneShotLatch getCheckpointLatch() { return checkpointLatch; } - public SubtaskState getCheckpointStateHandles() { + public TaskStateSnapshot getCheckpointStateHandles() { return checkpointStateHandles; } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java index 47a53500be822..b3b0a9f414e35 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource; import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -64,6 +65,7 @@ public void testCheckpointsTriggeredBySource() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource sourceOperator = new StreamSource<>(source); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); // this starts the source thread testHarness.invoke(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java index 27818bcafa36d..8867632a5c3ff 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.functions.source.RichSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -63,6 +64,7 @@ public void testOpenClose() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource sourceOperator = new StreamSource<>(new OpenCloseTestSource()); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); testHarness.invoke(); testHarness.waitForTaskCompletion(); @@ -106,6 +108,7 @@ public void testCheckpointing() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamSource, ?> sourceOperator = new StreamSource<>(new MockSource(numElements, sourceCheckpointDelay, sourceReadDelay)); streamConfig.setStreamOperator(sourceOperator); + streamConfig.setOperatorID(new OperatorID()); // prepare the diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 5b995c67b8e94..231f59e97fb2a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -28,7 +28,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; -import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -333,7 +333,7 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin } @Override - public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) { + public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java index 6e3c299f9ed9d..36bdc054b9340 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -91,6 +92,7 @@ public void testDeclineCallOnCancelBarrierOneInput() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); StreamMap mapOperator = new StreamMap<>(new IdentityMap()); streamConfig.setStreamOperator(mapOperator); + streamConfig.setOperatorID(new OperatorID()); StreamMockEnvironment environment = spy(testHarness.createEnvironment()); @@ -135,6 +137,7 @@ public void testDeclineCallOnCancelBarrierTwoInputs() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap op = new CoStreamMap<>(new UnionCoMap()); streamConfig.setStreamOperator(op); + streamConfig.setOperatorID(new OperatorID()); StreamMockEnvironment environment = spy(testHarness.createEnvironment()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java index f021b389c811e..86e93dd0a89c2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java @@ -44,6 +44,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -108,6 +109,7 @@ public void testConcurrentAsyncCheckpointCannotFailFinishedStreamTask() throws E final AbstractStateBackend blockingStateBackend = new BlockingStateBackend(); streamConfig.setStreamOperator(noOpStreamOperator); + streamConfig.setOperatorID(new OperatorID()); streamConfig.setStateBackend(blockingStateBackend); final long checkpointId = 0L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index 923b912437a2c..cab91d8f879b5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -32,7 +32,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -49,6 +51,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; @@ -56,7 +59,6 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.KeyGroupRange; @@ -158,6 +160,7 @@ public class StreamTaskTest extends TestLogger { public void testEarlyCanceling() throws Exception { Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow(); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new SlowlyDeserializingOperator()); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -203,6 +206,7 @@ public void testStateBackendLoadingAndClosing() throws Exception { taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName()); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction())); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -227,6 +231,7 @@ public void testStateBackendClosingOnFailure() throws Exception { taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName()); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setOperatorID(new OperatorID(4711L, 42L)); cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction())); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); @@ -324,6 +329,13 @@ public void testFailingCheckpointStreamOperator() throws Exception { when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2); when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3); + OperatorID operatorID1 = new OperatorID(); + OperatorID operatorID2 = new OperatorID(); + OperatorID operatorID3 = new OperatorID(); + when(streamOperator1.getOperatorID()).thenReturn(operatorID1); + when(streamOperator2.getOperatorID()).thenReturn(operatorID2); + when(streamOperator3.getOperatorID()).thenReturn(operatorID3); + // set up the task StreamOperator[] streamOperators = {streamOperator1, streamOperator2, streamOperator3}; @@ -399,6 +411,13 @@ public void testFailingAsyncCheckpointRunnable() throws Exception { when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2); when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3); + OperatorID operatorID1 = new OperatorID(); + OperatorID operatorID2 = new OperatorID(); + OperatorID operatorID3 = new OperatorID(); + when(streamOperator1.getOperatorID()).thenReturn(operatorID1); + when(streamOperator2.getOperatorID()).thenReturn(operatorID2); + when(streamOperator3.getOperatorID()).thenReturn(operatorID3); + StreamOperator[] streamOperators = {streamOperator1, streamOperator2, streamOperator3}; OperatorChain> operatorChain = mock(OperatorChain.class); @@ -455,7 +474,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class)); + }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); StreamTask> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); @@ -505,18 +524,19 @@ public Object answer(InvocationOnMock invocation) throws Throwable { acknowledgeCheckpointLatch.await(); - ArgumentCaptor subtaskStateCaptor = ArgumentCaptor.forClass(SubtaskState.class); + ArgumentCaptor subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class); // check that the checkpoint has been completed verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture()); - SubtaskState subtaskState = subtaskStateCaptor.getValue(); + TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue(); + OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); // check that the subtask state contains the expected state handles assertEquals(managedKeyedStateHandle, subtaskState.getManagedKeyedState()); assertEquals(rawKeyedStateHandle, subtaskState.getRawKeyedState()); - assertEquals(new ChainedStateHandle<>(Collections.singletonList(managedOperatorStateHandle)), subtaskState.getManagedOperatorState()); - assertEquals(new ChainedStateHandle<>(Collections.singletonList(rawOperatorStateHandle)), subtaskState.getRawOperatorState()); + assertEquals(managedOperatorStateHandle, subtaskState.getManagedOperatorState()); + assertEquals(rawOperatorStateHandle, subtaskState.getRawOperatorState()); // check that the state handles have not been discarded verify(managedKeyedStateHandle, never()).discardState(); @@ -558,18 +578,19 @@ public void testAsyncCheckpointingConcurrentCloseBeforeAcknowledge() throws Exce Environment mockEnvironment = mock(Environment.class); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); - whenNew(SubtaskState.class).withAnyArguments().thenAnswer(new Answer() { + whenNew(OperatorSubtaskState.class).withAnyArguments().thenAnswer(new Answer() { @Override - public SubtaskState answer(InvocationOnMock invocation) throws Throwable { + public OperatorSubtaskState answer(InvocationOnMock invocation) throws Throwable { createSubtask.trigger(); completeSubtask.await(); - - return new SubtaskState( - (ChainedStateHandle) invocation.getArguments()[0], - (ChainedStateHandle) invocation.getArguments()[1], - (ChainedStateHandle) invocation.getArguments()[2], - (KeyedStateHandle) invocation.getArguments()[3], - (KeyedStateHandle) invocation.getArguments()[4]); + Object[] arguments = invocation.getArguments(); + return new OperatorSubtaskState( + (StreamStateHandle) arguments[0], + (OperatorStateHandle) arguments[1], + (OperatorStateHandle) arguments[2], + (KeyedStateHandle) arguments[3], + (KeyedStateHandle) arguments[4] + ); } }); @@ -577,7 +598,9 @@ public SubtaskState answer(InvocationOnMock invocation) throws Throwable { CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); streamTask.setEnvironment(mockEnvironment); - StreamOperator streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + final StreamOperator streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + final OperatorID operatorID = new OperatorID(); + when(streamOperator.getOperatorID()).thenReturn(operatorID); KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class); KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class); @@ -636,7 +659,7 @@ public SubtaskState answer(InvocationOnMock invocation) throws Throwable { } // check that the checkpoint has not been acknowledged - verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(SubtaskState.class)); + verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); // check that the state handles have been discarded verify(managedKeyedStateHandle).discardState(); @@ -676,7 +699,7 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { checkpointCompletedLatch.trigger(); return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class)); + }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); @@ -688,6 +711,9 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { StreamOperator statelessOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class)); + final OperatorID operatorID = new OperatorID(); + when(statelessOperator.getOperatorID()).thenReturn(operatorID); + // mock the returned empty snapshot result (all state handles are null) OperatorSnapshotResult statelessOperatorSnapshotResult = new OperatorSnapshotResult(); when(statelessOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))) diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index a02fe4e7d9a56..19d48e195f2ef 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; @@ -142,6 +143,7 @@ public void setupOutputForSingletonOperatorChain() { streamConfig.setNumberOfOutputs(1); streamConfig.setTypeSerializerOut(outputSerializer); streamConfig.setVertexID(0); + streamConfig.setOperatorID(new OperatorID(4711L, 123L)); StreamOperator dummyOperator = new AbstractStreamOperator() { private static final long serialVersionUID = 1L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java index 66531ac51ce28..d785c0d7517f5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.co.RichCoMapFunction; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -64,6 +65,7 @@ public void testOpenCloseAndTimestamps() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new TestOpenCloseMapFunction()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); long initialTime = 0L; ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); @@ -110,6 +112,7 @@ public void testWatermarkAndStreamStatusForwarding() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -216,6 +219,7 @@ public void testCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; @@ -296,6 +300,7 @@ public void testOvertakingCheckpointBarriers() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); CoStreamMap coMapOperator = new CoStreamMap(new IdentityMap()); streamConfig.setStreamOperator(coMapOperator); + streamConfig.setOperatorID(new OperatorID()); ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue(); long initialTime = 0L; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java index 47e8726874d33..b1a7d69d880dd 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.CheckpointStreamFactory; @@ -154,6 +155,7 @@ public AbstractStreamOperatorTestHarness( Configuration underlyingConfig = environment.getTaskConfiguration(); this.config = new StreamConfig(underlyingConfig); this.config.setCheckpointingEnabled(true); + this.config.setOperatorID(new OperatorID()); this.executionConfig = environment.getExecutionConfig(); this.closableRegistry = new CloseableRegistry(); this.checkpointLock = new Object(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java new file mode 100644 index 0000000000000..0bb3ddf08b0a9 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.util; + +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.streaming.api.graph.StreamConfig; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Utility to convert state between operator id mapped and chain mapped. + */ +public class OperatorIDMappedStateToChainConverter { + + public static TaskStateHandles convert( + TaskStateSnapshot subtaskStates, + StreamConfig streamConfig, + int chainLength) { + + List operatorIDsInChainOrder = new ArrayList<>(chainLength); + operatorIDsInChainOrder.add(streamConfig.getOperatorID()); + Map chainedTaskConfigs = + streamConfig.getTransitiveChainedTaskConfigs(streamConfig.getClass().getClassLoader()); + for (int i = 1; i < chainLength; ++i) { + operatorIDsInChainOrder.add(chainedTaskConfigs.get(i).getOperatorID()); + } + return convert(subtaskStates, operatorIDsInChainOrder); + } + + public static TaskStateHandles convert(TaskStateSnapshot subtaskStates, List operatorIDsInChainOrder) { + final int chainLength = operatorIDsInChainOrder.size(); + + List legacyStateChain = new ArrayList<>(chainLength); + List> managedOpState = new ArrayList<>(chainLength); + List> rawOpState = new ArrayList<>(chainLength); + + for (int i = 1; i < chainLength; ++i) { + OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateByOperatorID(operatorIDsInChainOrder.get(i)); + legacyStateChain.add(subtaskState.getLegacyOperatorState()); + managedOpState.add(singletonListOrNull(subtaskState.getManagedOperatorState())); + rawOpState.add(singletonListOrNull(subtaskState.getRawOperatorState())); + } + + OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateByOperatorID(operatorIDsInChainOrder.get(0)); + legacyStateChain.add(subtaskState.getLegacyOperatorState()); + managedOpState.add(singletonListOrNull(subtaskState.getManagedOperatorState())); + rawOpState.add(singletonListOrNull(subtaskState.getRawOperatorState())); + + ChainedStateHandle legacyChainedStateHandle = new ChainedStateHandle<>(legacyStateChain); + + TaskStateHandles taskStateHandles = new TaskStateHandles( + legacyChainedStateHandle, + managedOpState, + rawOpState, + singletonListOrNull(subtaskState.getManagedKeyedState()), + singletonListOrNull(subtaskState.getRawKeyedState()) + ); + + return taskStateHandles; + } + + private static List singletonListOrNull(T item) { + return item != null ? Collections.singletonList(item) : null; + } +} From a50eda8602d2034753b42413d23842a888e73611 Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Tue, 11 Jul 2017 17:10:03 +0200 Subject: [PATCH 2/5] [FLINK-7213] Introduce TaskStateSnapshot to unify TaskStateHandles and SubtaskState --- .../checkpoint/OperatorSubtaskState.java | 66 +++++-- .../RoundRobinOperatorStateRepartitioner.java | 4 + .../checkpoint/StateAssignmentOperation.java | 179 ++++++++++-------- .../runtime/checkpoint/TaskStateSnapshot.java | 5 +- .../savepoint/SavepointV2Serializer.java | 20 +- .../deployment/TaskDeploymentDescriptor.java | 8 +- .../runtime/executiongraph/Execution.java | 8 +- .../executiongraph/ExecutionVertex.java | 8 +- .../runtime/jobgraph/tasks/StatefulTask.java | 8 +- .../state/StateInitializationContextImpl.java | 11 +- .../flink/runtime/state/TaskStateHandles.java | 15 +- .../flink/runtime/taskmanager/Task.java | 8 +- .../CheckpointCoordinatorFailureTest.java | 9 +- .../checkpoint/CheckpointCoordinatorTest.java | 136 ++++++++----- .../CheckpointStateRestoreTest.java | 40 ++-- .../CompletedCheckpointStoreTest.java | 2 +- .../TaskDeploymentDescriptorTest.java | 4 +- .../ExecutionVertexLocalityTest.java | 10 +- .../jobmanager/JobManagerHARecoveryTest.java | 12 +- .../taskmanager/TaskAsyncCallTest.java | 6 +- .../runtime/taskmanager/TaskStopTest.java | 26 +-- .../api/operators/AbstractStreamOperator.java | 12 +- .../api/operators/StreamOperator.java | 4 +- .../streaming/runtime/tasks/StreamTask.java | 49 ++--- .../async/AsyncWaitOperatorTest.java | 5 +- .../runtime/io/BarrierBufferTest.java | 4 +- .../runtime/io/BarrierTrackerTest.java | 4 +- .../tasks/InterruptSensitiveRestoreTest.java | 29 +-- .../runtime/tasks/OneInputStreamTaskTest.java | 7 +- .../runtime/tasks/StreamTaskTest.java | 23 ++- .../AbstractStreamOperatorTestHarness.java | 20 +- ...OperatorIDMappedStateToChainConverter.java | 91 --------- .../test/checkpointing/SavepointITCase.java | 2 +- 33 files changed, 420 insertions(+), 415 deletions(-) delete mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index afe2c2fe87769..d4c79ebe4dace 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.state.CompositeStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; @@ -30,6 +31,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -39,9 +43,15 @@ * This class encapsulates the state for one parallel instance of an operator. The complete state of a (logical) * operator (e.g. a flatmap operator) consists of the union of all {@link OperatorSubtaskState}s from all * parallel tasks that physically execute parallelized, physical instances of the operator. - *

- * The full state of the logical operator is represented by {@link OperatorState} which consists of + *

The full state of the logical operator is represented by {@link OperatorState} which consists of * {@link OperatorSubtaskState}s. + *

Typically, we expect all collections in this class to be of size 0 or 1, because there up to one state handle + * produced per state type (e.g. managed-keyed, raw-operator, ...). In particular, this holds when taking a snapshot. + * The purpose of having the state handles in collections is that this class is also reused in restoring state. + * Under normal circumstances, the expected size of each collection is still 0 or 1, except for scale-down. In + * scale-down, one operator subtask can become responsible for the state of multiple previous subtasks. The collections + * can then store all the state handles that are relevant to build up the new subtask state. + *

There is no collection for legacy state because it is nor rescalable. */ public class OperatorSubtaskState implements CompositeStateHandle { @@ -56,26 +66,31 @@ public class OperatorSubtaskState implements CompositeStateHandle { * Can be removed when we remove the APIs for non-repartitionable operator state. */ @Deprecated + @Nullable private final StreamStateHandle legacyOperatorState; /** * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}. */ + @Nonnull private final Collection managedOperatorState; /** * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}. */ + @Nonnull private final Collection rawOperatorState; /** * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}. */ + @Nonnull private final Collection managedKeyedState; /** * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}. */ + @Nonnull private final Collection rawKeyedState; /** @@ -85,13 +100,26 @@ public class OperatorSubtaskState implements CompositeStateHandle { */ private final long stateSize; - public OperatorSubtaskState() { - this.legacyOperatorState = null; + @VisibleForTesting + public OperatorSubtaskState(StreamStateHandle legacyOperatorState) { + + this.legacyOperatorState = legacyOperatorState; this.managedOperatorState = Collections.emptyList(); this.rawOperatorState = Collections.emptyList(); this.managedKeyedState = Collections.emptyList(); this.rawKeyedState = Collections.emptyList(); - this.stateSize = 0L; + try { + this.stateSize = getSizeNullSafe(legacyOperatorState); + } catch (Exception e) { + throw new RuntimeException("Failed to get state size.", e); + } + } + + /** + * Empty state. + */ + public OperatorSubtaskState() { + this(null); } public OperatorSubtaskState( @@ -108,7 +136,7 @@ public OperatorSubtaskState( this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState); try { - long calculateStateSize = sumAllSizes(legacyOperatorState); + long calculateStateSize = getSizeNullSafe(legacyOperatorState); calculateStateSize += sumAllSizes(managedOperatorState); calculateStateSize += sumAllSizes(rawOperatorState); calculateStateSize += sumAllSizes(managedKeyedState); @@ -119,6 +147,10 @@ public OperatorSubtaskState( } } + /** + * For convenience because the size of the collections is typically 0 or 1. Null values are translated into empty + * Collections (except for legacy state). + */ public OperatorSubtaskState( StreamStateHandle legacyOperatorState, OperatorStateHandle managedOperatorState, @@ -127,23 +159,26 @@ public OperatorSubtaskState( KeyedStateHandle rawKeyedState) { this(legacyOperatorState, - Collections.singletonList(managedOperatorState), - Collections.singletonList(rawOperatorState), - Collections.singletonList(managedKeyedState), - Collections.singletonList(rawKeyedState)); + singletonOrEmptyOnNull(managedOperatorState), + singletonOrEmptyOnNull(rawOperatorState), + singletonOrEmptyOnNull(managedKeyedState), + singletonOrEmptyOnNull(rawKeyedState)); } - private static long sumAllSizes(Collection stateObject) throws Exception { + private static Collection singletonOrEmptyOnNull(T element) { + return element != null ? Collections.singletonList(element) : Collections.emptyList(); + } + private static long sumAllSizes(Collection stateObject) throws Exception { long size = 0L; for (StateObject object : stateObject) { - size += sumAllSizes(object); + size += getSizeNullSafe(object); } return size; } - private static long sumAllSizes(StateObject stateObject) throws Exception { + private static long getSizeNullSafe(StateObject stateObject) throws Exception { return stateObject != null ? stateObject.getStateSize() : 0L; } @@ -154,6 +189,7 @@ private static long sumAllSizes(StateObject stateObject) throws Exception { * Can be removed when we remove the APIs for non-repartitionable operator state. */ @Deprecated + @Nullable public StreamStateHandle getLegacyOperatorState() { return legacyOperatorState; } @@ -161,6 +197,7 @@ public StreamStateHandle getLegacyOperatorState() { /** * Returns a handle to the managed operator state. */ + @Nonnull public Collection getManagedOperatorState() { return managedOperatorState; } @@ -168,6 +205,7 @@ public Collection getManagedOperatorState() { /** * Returns a handle to the raw operator state. */ + @Nonnull public Collection getRawOperatorState() { return rawOperatorState; } @@ -175,6 +213,7 @@ public Collection getRawOperatorState() { /** * Returns a handle to the managed keyed state. */ + @Nonnull public Collection getManagedKeyedState() { return managedKeyedState; } @@ -182,6 +221,7 @@ public Collection getManagedKeyedState() { /** * Returns a handle to the raw keyed state. */ + @Nonnull public Collection getRawKeyedState() { return rawKeyedState; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 046096fc85c08..5bf9115756f91 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -89,6 +89,10 @@ private GroupByStateNameResults groupByStateName( for (OperatorStateHandle psh : previousParallelSubtaskStates) { + if(psh == null) { + continue; + } + for (Map.Entry e : psh.getStateNameToPartitionOffsets().entrySet()) { OperatorStateHandle.StateMetaInfo metaInfo = e.getValue(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 250c63e248546..b69285ed5a69e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -23,14 +23,12 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -186,7 +184,8 @@ private void assignAttemptState(ExecutionJobVertex executionJobVertex, List(subNonPartitionableState), - subManagedOperatorState, - subRawOperatorState, - subKeyedState != null ? subKeyedState.f0 : null, - subKeyedState != null ? subKeyedState.f1 : null); + for (int i = 0; i < operatorIDs.size(); ++i) { + + OperatorID operatorID = operatorIDs.get(i); + + Collection rawKeyed = Collections.emptyList(); + Collection managedKeyed = Collections.emptyList(); + + // keyed state case + if (subKeyedState != null) { + managedKeyed = subKeyedState.f0; + rawKeyed = subKeyedState.f1; + } + + OperatorSubtaskState operatorSubtaskState = + new OperatorSubtaskState( + subNonPartitionableState.get(i), + subManagedOperatorState.get(i), + subRawOperatorState.get(i), + managedKeyed, + rawKeyed + ); + + taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + } - currentExecutionAttempt.setInitialState(taskStateHandles); + currentExecutionAttempt.setInitialState(taskState); } } } + private static boolean isHeadOperator(int opIdx, List operatorIDs) { + return opIdx == operatorIDs.size() - 1; + } public void checkParallelismPreconditions(List operatorStates, ExecutionJobVertex executionJobVertex) { @@ -240,18 +260,18 @@ private void reAssignSubPartitionableState( List> subManagedOperatorState, List> subRawOperatorState) { - if (newMangedOperatorStates.get(operatorIndex) != null) { - subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex)); + if (newMangedOperatorStates.get(operatorIndex) != null && !newMangedOperatorStates.get(operatorIndex).isEmpty()) { + Collection operatorStateHandles = newMangedOperatorStates.get(operatorIndex).get(subTaskIndex); + subManagedOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.emptyList()); } else { - subManagedOperatorState.add(null); + subManagedOperatorState.add(Collections.emptyList()); } - if (newRawOperatorStates.get(operatorIndex) != null) { - subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex)); + if (newRawOperatorStates.get(operatorIndex) != null && !newRawOperatorStates.get(operatorIndex).isEmpty()) { + Collection operatorStateHandles = newRawOperatorStates.get(operatorIndex).get(subTaskIndex); + subRawOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.emptyList()); } else { - subRawOperatorState.add(null); + subRawOperatorState.add(Collections.emptyList()); } - - } private Tuple2, Collection> reAssignSubKeyedStates( @@ -269,17 +289,19 @@ private Tuple2, Collection> reAss subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); } else { - subManagedKeyedState = null; - subRawKeyedState = null; + subManagedKeyedState = Collections.emptyList(); + subRawKeyedState = Collections.emptyList(); } } else { subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); } - if (subManagedKeyedState == null && subRawKeyedState == null) { + + if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) { return null; + } else { + return new Tuple2<>(subManagedKeyedState, subRawKeyedState); } - return new Tuple2<>(subManagedKeyedState, subRawKeyedState); } @@ -315,7 +337,7 @@ private void reDistributePartitionableStates( List>> newManagedOperatorStates, List>> newRawOperatorStates) { - //collect the old partitionalbe state + //collect the old partitionable state List> oldManagedOperatorStates = new ArrayList<>(); List> oldRawOperatorStates = new ArrayList<>(); @@ -348,19 +370,16 @@ private void collectPartionableStates( for (int i = 0; i < operatorState.getParallelism(); i++) { OperatorSubtaskState operatorSubtaskState = operatorState.getState(i); if (operatorSubtaskState != null) { - if (operatorSubtaskState.getManagedOperatorState() != null) { - if (managedOperatorState == null) { - managedOperatorState = new ArrayList<>(); - } - managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState()); + + if (managedOperatorState == null) { + managedOperatorState = new ArrayList<>(); } + managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState()); - if (operatorSubtaskState.getRawOperatorState() != null) { - if (rawOperatorState == null) { - rawOperatorState = new ArrayList<>(); - } - rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState()); + if (rawOperatorState == null) { + rawOperatorState = new ArrayList<>(); } + rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState()); } } @@ -379,31 +398,19 @@ private void collectPartionableStates( * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange */ public static List getManagedKeyedStateHandles( - OperatorState operatorState, - KeyGroupRange subtaskKeyGroupRange) { + OperatorState operatorState, + KeyGroupRange subtaskKeyGroupRange) { - List subtaskKeyedStateHandles = null; + List subtaskKeyedStateHandles = new ArrayList<>(); for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) { + if (operatorState.getState(i) != null) { Collection keyedStateHandles = operatorState.getState(i).getManagedKeyedState(); - for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { - - //TODO deduplicate code!!!!!!! - if(keyedStateHandle != null) { - - KeyedStateHandle intersectedKeyedStateHandle = - keyedStateHandle.getIntersection(subtaskKeyGroupRange); - - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); - } - } - } + extractIntersectingState( + keyedStateHandles, + subtaskKeyGroupRange, + subtaskKeyedStateHandles); } } @@ -422,31 +429,40 @@ public static List getRawKeyedStateHandles( OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) { - List subtaskKeyedStateHandles = null; + List extractedKeyedStateHandles = new ArrayList<>(); for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) { - + if (operatorState.getState(i) != null) { Collection rawKeyedState = operatorState.getState(i).getRawKeyedState(); + extractIntersectingState( + rawKeyedState, + subtaskKeyGroupRange, + extractedKeyedStateHandles); + } + } + + return extractedKeyedStateHandles; + } + + /** + * Extracts certain key group ranges from the given state handles and adds them to the collector. + */ + private static void extractIntersectingState( + Collection originalSubtaskStateHandles, + KeyGroupRange rangeToExtract, + List extractedStateCollector) { - for (KeyedStateHandle keyedStateHandle : rawKeyedState) { + for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) { - if (keyedStateHandle != null) { + if (keyedStateHandle != null) { - KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange); + KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(rangeToExtract); - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); - } - } + if (intersectedKeyedStateHandle != null) { + extractedStateCollector.add(intersectedKeyedStateHandle); } } } - - return subtaskKeyedStateHandles; } /** @@ -570,7 +586,7 @@ public static List> applyRepartitioner( int newParallelism) { if (chainOpParallelStates == null) { - return null; + return Collections.emptyList(); } //We only redistribute if the parallelism of the operator changed from previous executions @@ -583,20 +599,23 @@ public static List> applyRepartitioner( List> repackStream = new ArrayList<>(newParallelism); for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { - Map partitionOffsets = + if (operatorStateHandle != null) { + Map partitionOffsets = operatorStateHandle.getStateNameToPartitionOffsets(); - for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { - // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning - if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { - return opStateRepartitioner.repartitionState( + for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { + + // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning + if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { + return opStateRepartitioner.repartitionState( chainOpParallelStates, newParallelism); + } } - } - repackStream.add(Collections.singletonList(operatorStateHandle)); + repackStream.add(Collections.singletonList(operatorStateHandle)); + } } return repackStream; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java index ee59ed97a8691..d464423134459 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java @@ -34,11 +34,12 @@ * register their state under their operator id. Each operator instance is a physical execution responsible for * processing a partition of the data that goes through a logical operator. This partitioning happens to parallelize * execution of logical operators, e.g. distributing a map function. - *

- * One instance of this class contains the information that one task will send to acknowledge a checkpoint request by t + *

One instance of this class contains the information that one task will send to acknowledge a checkpoint request by t * he checkpoint coordinator. Tasks run operator instances in parallel, so the union of all * {@link TaskStateSnapshot} that are collected by the checkpoint coordinator from all tasks represent the whole * state of a job at the time of the checkpoint. + *

This class should be called TaskState once the old class with this name that we keep for backwards + * compatibility goes away. */ public class TaskStateSnapshot implements CompositeStateHandle { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java index 4cbbfcfba8b5b..15628a0429e31 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java @@ -240,6 +240,18 @@ private MasterState deserializeMasterState(DataInputStream dis) throws IOExcepti // task state (de)serialization methods // ------------------------------------------------------------------------ + private static T extractSingleton(Collection collection) { + if (collection == null || collection.isEmpty()) { + return null; + } + + if (collection.size() == 1) { + return collection.iterator().next(); + } else { + throw new IllegalStateException("Expected singleton collection, but found size: " + collection.size()); + } + } + private static void serializeSubtaskState(OperatorSubtaskState subtaskState, DataOutputStream dos) throws IOException { dos.writeLong(-1); @@ -252,7 +264,7 @@ private static void serializeSubtaskState(OperatorSubtaskState subtaskState, Dat serializeStreamStateHandle(nonPartitionableState, dos); } - OperatorStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); + OperatorStateHandle operatorStateBackend = extractSingleton(subtaskState.getManagedOperatorState()); len = operatorStateBackend != null ? 1 : 0; dos.writeInt(len); @@ -260,7 +272,7 @@ private static void serializeSubtaskState(OperatorSubtaskState subtaskState, Dat serializeOperatorStateHandle(operatorStateBackend, dos); } - OperatorStateHandle operatorStateFromStream = subtaskState.getRawOperatorState(); + OperatorStateHandle operatorStateFromStream = extractSingleton(subtaskState.getRawOperatorState()); len = operatorStateFromStream != null ? 1 : 0; dos.writeInt(len); @@ -268,10 +280,10 @@ private static void serializeSubtaskState(OperatorSubtaskState subtaskState, Dat serializeOperatorStateHandle(operatorStateFromStream, dos); } - KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); + KeyedStateHandle keyedStateBackend = extractSingleton(subtaskState.getManagedKeyedState()); serializeKeyedStateHandle(keyedStateBackend, dos); - KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState(); + KeyedStateHandle keyedStateStream = extractSingleton(subtaskState.getRawKeyedState()); serializeKeyedStateHandle(keyedStateStream, dos); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 0578b787290d9..1fa5eb5b51484 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -18,11 +18,11 @@ package org.apache.flink.runtime.deployment; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; import org.apache.flink.runtime.executiongraph.TaskInformation; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; @@ -64,7 +64,7 @@ public final class TaskDeploymentDescriptor implements Serializable { private final int targetSlotNumber; /** State handles for the sub task. */ - private final TaskStateHandles taskStateHandles; + private final TaskStateSnapshot taskStateHandles; public TaskDeploymentDescriptor( SerializedValue serializedJobInformation, @@ -74,7 +74,7 @@ public TaskDeploymentDescriptor( int subtaskIndex, int attemptNumber, int targetSlotNumber, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors) { @@ -153,7 +153,7 @@ public Collection getInputGates() { return inputGates; } - public TaskStateHandles getTaskStateHandles() { + public TaskStateSnapshot getTaskStateHandles() { return taskStateHandles; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index c0f1f39fdc5f4..98fbbacd8c915 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.concurrent.ApplyFunction; import org.apache.flink.runtime.concurrent.BiFunction; @@ -46,7 +47,6 @@ import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.messages.StackTraceSampleResponse; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.ExceptionUtils; @@ -138,7 +138,7 @@ public class Execution implements AccessExecution, Archiveable getPreferredLocations() { */ public Iterable getPreferredLocationsBasedOnState() { TaskManagerLocation priorLocation; - if (currentExecution.getTaskStateHandles() != null && (priorLocation = getLatestPriorLocation()) != null) { + if (currentExecution.getTaskStateSnapshot() != null && (priorLocation = getLatestPriorLocation()) != null) { return Collections.singleton(priorLocation); } else { @@ -719,7 +719,7 @@ void notifyStateTransition(Execution execution, ExecutionState newState, Throwab TaskDeploymentDescriptor createDeploymentDescriptor( ExecutionAttemptID executionId, SimpleSlot targetSlot, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, int attemptNumber) throws ExecutionGraphException { // Produced intermediate results diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java index 0930011896353..00db01ffd2e04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java @@ -21,7 +21,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; /** * This interface must be implemented by any invokable that has recoverable state and participates @@ -35,7 +35,7 @@ public interface StatefulTask { * * @param taskStateHandles All state handle for the task. */ - void setInitialState(TaskStateHandles taskStateHandles) throws Exception; + void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception; /** * This method is called to trigger a checkpoint, asynchronously by the checkpoint @@ -43,8 +43,8 @@ public interface StatefulTask { * *

This method is called for tasks that start the checkpoints by injecting the initial barriers, * i.e., the source tasks. In contrast, checkpoints on downstream operators, which are the result of - * receiving checkpoint barriers, invoke the {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointMetrics)} - * method. + * receiving checkpoint barriers, invoke the + * {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointOptions, CheckpointMetrics)} method. * * @param checkpointMetaData Meta data for about this checkpoint * @param checkpointOptions Options for performing this checkpoint diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java index d82af7217a7cd..031d7c717284c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.state; -import org.apache.commons.io.IOUtils; import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.java.tuple.Tuple2; @@ -26,6 +25,8 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.util.Preconditions; +import org.apache.commons.io.IOUtils; + import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -139,6 +140,7 @@ public void close() { } private static Collection transform(Collection keyedStateHandles) { + if (keyedStateHandles == null) { return null; } @@ -146,13 +148,14 @@ private static Collection transform(Collection keyGroupsStateHandles = new ArrayList<>(); for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { - if (! (keyedStateHandle instanceof KeyGroupsStateHandle)) { + + if (keyedStateHandle instanceof KeyGroupsStateHandle) { + keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); + } else if (keyedStateHandle != null) { throw new IllegalStateException("Unexpected state handle type, " + "expected: " + KeyGroupsStateHandle.class + ", but found: " + keyedStateHandle.getClass() + "."); } - - keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); } return keyGroupsStateHandles; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java index 2fde5485049f9..9a00e68854435 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.state; import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.util.Preconditions; import java.io.Serializable; import java.util.ArrayList; @@ -57,7 +58,11 @@ public class TaskStateHandles implements Serializable { private final List> rawOperatorState; public TaskStateHandles() { - this(null, null, null, null, null); + this(null, + Collections.>emptyList(), + Collections.>emptyList(), + Collections.emptyList(), + Collections.emptyList()); } public TaskStateHandles(SubtaskState checkpointStateHandles) { @@ -76,10 +81,10 @@ public TaskStateHandles( Collection rawKeyedState) { this.legacyOperatorState = legacyOperatorState; - this.managedKeyedState = managedKeyedState; - this.rawKeyedState = rawKeyedState; - this.managedOperatorState = managedOperatorState; - this.rawOperatorState = rawOperatorState; + this.managedKeyedState = Preconditions.checkNotNull(managedKeyedState); + this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState); + this.managedOperatorState = Preconditions.checkNotNull(managedOperatorState); + this.rawOperatorState = Preconditions.checkNotNull(rawOperatorState); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 65ce18c5c5628..9fb1fc7363ede 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotCheckpointingException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotReadyException; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -68,16 +69,17 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.WrappingRuntimeException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.io.IOException; import java.net.URL; import java.util.Collection; @@ -250,7 +252,7 @@ public class Task implements Runnable, TaskActions { * The handles to the states that the task was initialized with. Will be set * to null after the initialization, to be memory friendly. */ - private volatile TaskStateHandles taskStateHandles; + private volatile TaskStateSnapshot taskStateHandles; /** Initialized from the Flink configuration. May also be set at the ExecutionConfig */ private long taskCancellationInterval; @@ -272,7 +274,7 @@ public Task( Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors, int targetSlotNumber, - TaskStateHandles taskStateHandles, + TaskStateSnapshot taskStateHandles, MemoryManager memManager, IOManager ioManager, NetworkEnvironment networkEnvironment, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 584d4fa31cbdb..88b95f5959fd7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -43,7 +43,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -127,10 +126,10 @@ public void testFailingCompletedCheckpointStoreAdd() throws Exception { // make sure that the subtask state has been discarded after we could not complete it. verify(operatorSubtaskState).discardState(); verify(operatorSubtaskState.getLegacyOperatorState()).discardState(); - verify(operatorSubtaskState.getManagedOperatorState()).discardState(); - verify(operatorSubtaskState.getRawOperatorState()).discardState(); - verify(operatorSubtaskState.getManagedKeyedState()).discardState(); - verify(operatorSubtaskState.getRawKeyedState()).discardState(); + verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState(); + verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState(); + verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 7b87d1ea81a44..6be660781bafd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -45,7 +45,6 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; @@ -2484,26 +2483,35 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s List>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); List>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); - KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + List operatorIDs = newJobVertex2.getOperatorIDs(); - ChainedStateHandle operatorState = taskStateHandles.getLegacyOperatorState(); - List> opStateBackend = taskStateHandles.getManagedOperatorState(); - List> opStateRaw = taskStateHandles.getRawOperatorState(); - Collection keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); + KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - actualOpStatesBackend.add(opStateBackend); - actualOpStatesRaw.add(opStateRaw); - // the 'non partition state' is not null because it is recombined. - assertNotNull(operatorState); - for (int index = 0; index < operatorState.getLength(); index++) { - assertNull(operatorState.get(index)); + TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + final int headOpIndex = operatorIDs.size() - 1; + List> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size()); + List> allParallelRawOpStates = new ArrayList<>(operatorIDs.size()); + + for (int idx = 0; idx < operatorIDs.size(); ++idx) { + OperatorID operatorID = operatorIDs.get(idx); + OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID); + Assert.assertNull(opState.getLegacyOperatorState()); + Collection opStateBackend = opState.getManagedOperatorState(); + Collection opStateRaw = opState.getRawOperatorState(); + allParallelManagedOpStates.add(opStateBackend); + allParallelRawOpStates.add(opStateRaw); + if (idx == headOpIndex) { + Collection keyedStateBackend = opState.getManagedKeyedState(); + Collection keyGroupStateRaw = opState.getRawKeyedState(); + compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); + compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + } } - compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); - compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); + actualOpStatesBackend.add(allParallelManagedOpStates); + actualOpStatesRaw.add(allParallelRawOpStates); } comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); @@ -2683,24 +2691,30 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception for (int i = 0; i < newJobVertex1.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); - ChainedStateHandle actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState(); - List> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState(); - List> actualSubRawOperatorState = taskStateHandles.getRawOperatorState(); + final List operatorIds = newJobVertex1.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); - assertNull(taskStateHandles.getManagedKeyedState()); - assertNull(taskStateHandles.getRawKeyedState()); + OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); + assertTrue(headOpState.getManagedKeyedState().isEmpty()); + assertTrue(headOpState.getRawKeyedState().isEmpty()); // operator5 { int operatorIndexInChain = 2; - assertNull(actualSubNonPartitionedState.get(operatorIndexInChain)); - assertNull(actualSubManagedOperatorState.get(operatorIndexInChain)); - assertNull(actualSubRawOperatorState.get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + + assertNull(opState.getLegacyOperatorState()); + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } // operator1 { int operatorIndexInChain = 1; + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i); OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id1.f0, i, 2, 8, false); @@ -2709,31 +2723,43 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception assertTrue(CommonTestUtils.isSteamContentEqual( expectSubNonPartitionedState.openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); + opState.getLegacyOperatorState().openInputStream())); + Collection managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + managedOperatorState.iterator().next().openInputStream())); + Collection rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + rawOperatorState.iterator().next().openInputStream())); } // operator2 { int operatorIndexInChain = 0; + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i); OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, false); OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle( id2.f0, i, 2, 8, true); - assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.openInputStream(), - actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream())); + assertTrue(CommonTestUtils.isSteamContentEqual( + expectSubNonPartitionedState.openInputStream(), + opState.getLegacyOperatorState().openInputStream())); + Collection managedOperatorState = opState.getManagedOperatorState(); + assertEquals(1, managedOperatorState.size()); assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(), - actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + managedOperatorState.iterator().next().openInputStream())); + Collection rawOperatorState = opState.getRawOperatorState(); + assertEquals(1, rawOperatorState.size()); assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(), - actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream())); + rawOperatorState.iterator().next().openInputStream())); } } @@ -2741,38 +2767,48 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception List>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + + final List operatorIds = newJobVertex2.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); // operator 3 { int operatorIndexInChain = 1; + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + List> actualSubManagedOperatorState = new ArrayList<>(1); - actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); + actualSubManagedOperatorState.add(opState.getManagedOperatorState()); List> actualSubRawOperatorState = new ArrayList<>(1); - actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); + actualSubRawOperatorState.add(opState.getRawOperatorState()); actualManagedOperatorStates.add(actualSubManagedOperatorState); actualRawOperatorStates.add(actualSubRawOperatorState); - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); + assertNull(opState.getLegacyOperatorState()); } // operator 6 { int operatorIndexInChain = 0; - assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain)); - assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain)); + OperatorSubtaskState opState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain)); + assertNull(opState.getLegacyOperatorState()); + assertTrue(opState.getManagedOperatorState().isEmpty()); + assertTrue(opState.getRawOperatorState().isEmpty()); } KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false); KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true); + OperatorSubtaskState headOpState = + stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); - Collection keyedStateBackend = taskStateHandles.getManagedKeyedState(); - Collection keyGroupStateRaw = taskStateHandles.getRawKeyedState(); + Collection keyedStateBackend = headOpState.getManagedKeyedState(); + Collection keyGroupStateRaw = headOpState.getRawKeyedState(); compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); @@ -3169,27 +3205,27 @@ public static void verifyStateRestore( for (int i = 0; i < executionJobVertex.getParallelism(); i++) { - TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + final List operatorIds = executionJobVertex.getOperatorIDs(); + + TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + + OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID)); StreamStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); - ChainedStateHandle actualNonPartitionedState = taskStateHandles.getLegacyOperatorState(); assertTrue(CommonTestUtils.isSteamContentEqual( expectNonPartitionedState.openInputStream(), - actualNonPartitionedState.get(0).openInputStream())); + operatorState.getLegacyOperatorState().openInputStream())); ChainedStateHandle expectedOpStateBackend = generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false); - List> actualPartitionableState = taskStateHandles.getManagedOperatorState(); - assertTrue(CommonTestUtils.isSteamContentEqual( expectedOpStateBackend.get(0).openInputStream(), - actualPartitionableState.get(0).iterator().next().openInputStream())); + operatorState.getManagedOperatorState().iterator().next().openInputStream())); KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState( jobVertexID, keyGroupPartitions.get(i), false); - Collection actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState(); - compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState); + compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index f4807a3ef8b54..53150350bb6b1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.SerializableObject; import org.hamcrest.BaseMatcher; @@ -42,11 +41,11 @@ import org.junit.Test; import org.mockito.Mockito; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -119,16 +118,16 @@ public void testSetState() { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - TaskStateSnapshot subtaskStates = new TaskStateSnapshot(); + final TaskStateSnapshot subtaskStates = new TaskStateSnapshot(); subtaskStates.putSubtaskStateByOperatorID( OperatorID.fromJobVertexID(statefulId), new OperatorSubtaskState( serializedState.get(0), - null, - null, - serializedKeyGroupStates, - null)); + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(serializedKeyGroupStates), + Collections.emptyList())); //SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); @@ -146,33 +145,26 @@ public void testSetState() { // verify that each stateful vertex got the state - final TaskStateHandles taskStateHandles = new TaskStateHandles( - serializedState, - Collections.>singletonList(null), - Collections.>singletonList(null), - Collections.singletonList(serializedKeyGroupStates), - null); - - BaseMatcher matcher = new BaseMatcher() { + BaseMatcher matcher = new BaseMatcher() { @Override public boolean matches(Object o) { - if (o instanceof TaskStateHandles) { - return o.equals(taskStateHandles); + if (o instanceof TaskStateSnapshot) { + return Objects.equals(o, subtaskStates); } return false; } @Override public void describeTo(Description description) { - description.appendValue(taskStateHandles); + description.appendValue(subtaskStates); } }; verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher)); - verify(statelessExec1, times(0)).setInitialState(Mockito.any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.any()); + verify(statelessExec1, times(0)).setInitialState(Mockito.any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.any()); } catch (Exception e) { e.printStackTrace(); @@ -263,9 +255,9 @@ public void testNonRestoredState() throws Exception { Map checkpointTaskStates = new HashMap<>(); { OperatorState taskState = new OperatorState(operatorId1, 3, 3); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null)); - taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState(serializedState)); + taskState.putState(1, new OperatorSubtaskState(serializedState)); + taskState.putState(2, new OperatorSubtaskState(serializedState)); checkpointTaskStates.put(operatorId1, taskState); } @@ -292,7 +284,7 @@ public void testNonRestoredState() throws Exception { // There is no task for this { OperatorState taskState = new OperatorState(newOperatorID, 1, 1); - taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(0, new OperatorSubtaskState(serializedState)); checkpointTaskStates.put(newOperatorID, taskState); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 1fe4e65979cc8..320dc2df52bfb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -331,7 +331,7 @@ static class TestOperatorSubtaskState extends OperatorSubtaskState { boolean discarded; public TestOperatorSubtaskState() { - super(null, null, null, null, null); + super(); this.registered = false; this.discarded = false; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java index 36c9cadeaec3e..9ed4851cad516 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.blob.BlobKey; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; @@ -30,7 +31,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.operators.BatchTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.SerializedValue; import org.junit.Test; @@ -73,7 +73,7 @@ public void testSerialization() { final SerializedValue serializedJobVertexInformation = new SerializedValue<>(new TaskInformation( vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration)); final int targetSlotNumber = 47; - final TaskStateHandles taskStateHandles = new TaskStateHandles(); + final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot(); final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor( serializedJobInformation, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java index 0eed90d271bcc..c9b7a40a78b0a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; @@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot; import org.apache.flink.runtime.jobmanager.slots.SlotOwner; import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testtasks.NoOpInvokable; @@ -51,8 +51,10 @@ import java.util.Iterator; import java.util.concurrent.TimeUnit; -import static org.mockito.Mockito.*; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; /** * Tests that the execution vertex handles locality preferences well. @@ -169,7 +171,7 @@ public void testLocalityBasedOnState() throws Exception { // target state ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i]; - target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class)); + target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class)); } // validate that the target vertices have the state's location as the location preference diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index 38964168e1576..23f0a389076c3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -61,7 +61,6 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.metrics.MetricRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -552,10 +551,10 @@ public static class BlockingStatefulInvokable extends BlockingInvokable implemen @Override public void setInitialState( - TaskStateHandles taskStateHandles) throws Exception { + TaskStateSnapshot taskStateHandles) throws Exception { int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { - try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) { + try (FSDataInputStream in = taskStateHandles.getSubtaskStateMappings().iterator().next().getValue().getLegacyOperatorState().openInputStream()) { recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); } } @@ -570,12 +569,7 @@ public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, Checkpoi TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot(); checkpointStateHandles.putSubtaskStateByOperatorID( OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()), - new OperatorSubtaskState( - byteStreamStateHandle, - null, - null, - null, - null) + new OperatorSubtaskState(byteStreamStateHandle) ); getEnvironment().acknowledgeCheckpoint( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index c6d2fec2f0daf..085a38699772f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -49,7 +50,6 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; @@ -187,7 +187,7 @@ private static Task createTask() throws Exception { Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateHandles(), + new TaskStateSnapshot(), mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, @@ -228,7 +228,7 @@ public void invoke() throws Exception { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {} + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {} @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java index 40678de125424..1ebd4adf73883 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java @@ -20,39 +20,41 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; -import org.apache.flink.runtime.executiongraph.JobInformation; -import org.apache.flink.runtime.executiongraph.TaskInformation; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; -import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.JobInformation; +import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.StoppableTask; import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; +import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; + import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; -import scala.concurrent.duration.FiniteDuration; import java.lang.reflect.Field; import java.util.Collections; import java.util.concurrent.Executor; +import scala.concurrent.duration.FiniteDuration; + import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -88,7 +90,7 @@ public void doMocking(AbstractInvokable taskMock) throws Exception { Collections.emptyList(), Collections.emptyList(), 0, - mock(TaskStateHandles.class), + mock(TaskStateSnapshot.class), mock(MemoryManager.class), mock(IOManager.class), mock(NetworkEnvironment.class), diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 324bc8c6bcbcd..a72b9fe3491e7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -36,6 +36,7 @@ import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; @@ -61,7 +62,6 @@ import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.OutputTag; @@ -208,13 +208,13 @@ public MetricGroup getMetricGroup() { } @Override - public final void initializeState(OperatorStateHandles stateHandles) throws Exception { + public final void initializeState(OperatorSubtaskState stateHandles) throws Exception { Collection keyedStateHandlesRaw = null; Collection operatorStateHandlesRaw = null; Collection operatorStateHandlesBackend = null; - boolean restoring = null != stateHandles; + boolean restoring = (null != stateHandles); initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class @@ -266,13 +266,13 @@ public final void initializeState(OperatorStateHandles stateHandles) throws Exce * Can be removed when we remove the APIs for non-repartitionable operator state. */ @Deprecated - private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception { + private void restoreStreamCheckpointed(OperatorSubtaskState stateHandles) throws Exception { StreamStateHandle state = stateHandles.getLegacyOperatorState(); if (null != state) { if (this instanceof CheckpointedRestoringOperator) { - LOG.debug("Restore state of task {} in chain ({}).", - stateHandles.getOperatorChainIndex(), getContainingTask().getName()); + LOG.debug("Restore state of task {} in operator with id ({}).", + getContainingTask().getName(), getOperatorID()); FSDataInputStream is = state.openInputStream(); try { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 3c26f50ebaf0e..9d5e02b1a3d37 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,11 +20,11 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.StreamTask; import java.io.Serializable; @@ -124,7 +124,7 @@ StreamStateHandle snapshotLegacyOperatorState( * * @param stateHandles state handles to the operator state. */ - void initializeState(OperatorStateHandles stateHandles) throws Exception; + void initializeState(OperatorSubtaskState stateHandles) throws Exception; /** * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 4a6a4fb7144a2..70b1d78ad85a7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -37,7 +37,6 @@ import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; @@ -46,7 +45,6 @@ import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -56,7 +54,6 @@ import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer; -import org.apache.flink.util.CollectionUtil; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FutureUtil; import org.apache.flink.util.Preconditions; @@ -68,7 +65,6 @@ import java.io.IOException; import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -159,7 +155,7 @@ public abstract class StreamTask> /** The map of user-defined accumulators of this task. */ private Map> accumulatorMap; - private TaskStateHandles restoreStateHandles; + private TaskStateSnapshot taskStateSnapshot; /** The currently active background materialization threads. */ private final CloseableRegistry cancelables = new CloseableRegistry(); @@ -509,8 +505,8 @@ RecordWriterOutput[] getStreamOutputs() { // ------------------------------------------------------------------------ @Override - public void setInitialState(TaskStateHandles taskStateHandles) { - this.restoreStateHandles = taskStateHandles; + public void setInitialState(TaskStateSnapshot taskStateHandles) { + this.taskStateSnapshot = taskStateHandles; } @Override @@ -659,12 +655,11 @@ private void checkpointState( private void initializeState() throws Exception { - boolean restored = null != restoreStateHandles; + boolean restored = null != taskStateSnapshot; if (restored) { - checkRestorePreconditions(operatorChain.getChainLength()); initializeOperators(true); - restoreStateHandles = null; // free for GC + taskStateSnapshot = null; // free for GC } else { initializeOperators(false); } @@ -675,9 +670,8 @@ private void initializeOperators(boolean restored) throws Exception { for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) { StreamOperator operator = allOperators[chainIdx]; if (null != operator) { - if (restored && restoreStateHandles != null) { - operator.initializeState(restoreStateHandles.getStateByOperatorID(operator.getOperatorID())); - operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx)); + if (restored && taskStateSnapshot != null) { + operator.initializeState(taskStateSnapshot.getSubtaskStateByOperatorID(operator.getOperatorID())); } else { operator.initializeState(null); } @@ -685,26 +679,6 @@ private void initializeOperators(boolean restored) throws Exception { } } - private void checkRestorePreconditions(int operatorChainLength) { - - ChainedStateHandle nonPartitionableOperatorStates = - restoreStateHandles.getLegacyOperatorState(); - List> operatorStates = - restoreStateHandles.getManagedOperatorState(); - - if (nonPartitionableOperatorStates != null) { - Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength, - "Invalid Invalid number of operator states. Found :" + nonPartitionableOperatorStates.getLength() - + ". Expected: " + operatorChainLength); - } - - if (!CollectionUtil.isNullOrEmpty(operatorStates)) { - Preconditions.checkArgument(operatorStates.size() == operatorChainLength, - "Invalid number of operator states. Found :" + operatorStates.size() + - ". Expected: " + operatorChainLength); - } - } - // ------------------------------------------------------------------------ // State backend // ------------------------------------------------------------------------ @@ -770,8 +744,13 @@ public AbstractKeyedStateBackend createKeyedStateBackend( cancelables.registerClosable(keyedStateBackend); // restore if we have some old state - Collection restoreKeyedStateHandles = - restoreStateHandles == null ? null : restoreStateHandles.getManagedKeyedState(); + Collection restoreKeyedStateHandles = null; + + if (taskStateSnapshot != null) { + OperatorSubtaskState stateByOperatorID = + taskStateSnapshot.getSubtaskStateByOperatorID(headOperator.getOperatorID()); + restoreKeyedStateHandles = stateByOperatorID != null ? stateByOperatorID.getManagedKeyedState() : null; + } keyedStateBackend.restore(restoreKeyedStateHandles); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java index abfb5bccce200..51941d04cc01b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java @@ -38,7 +38,6 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.datastream.AsyncDataStream; @@ -63,7 +62,6 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; -import org.apache.flink.streaming.util.OperatorIDMappedStateToChainConverter; import org.apache.flink.streaming.util.TestHarnessUtil; import org.apache.flink.util.Preconditions; import org.apache.flink.util.TestLogger; @@ -544,8 +542,7 @@ public void testStateSnapshotAndRestore() throws Exception { // set the operator state from previous attempt into the restored one final OneInputStreamTask restoredTask = new OneInputStreamTask<>(); TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles(); - TaskStateHandles stateHandles = OperatorIDMappedStateToChainConverter.convert(subtaskStates, streamConfig, 1); - restoredTask.setInitialState(stateHandles); + restoredTask.setInitialState(subtaskStates); final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java index c2cf7f3f91ed5..491b23d17b057 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineSubsumedException; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -34,7 +35,6 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -1484,7 +1484,7 @@ long getLastReportedBytesBufferedInAlignment() { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception { + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { throw new UnsupportedOperationException("should never be called"); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java index 847db5cec006f..cde90104b91ac 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java @@ -22,13 +22,13 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; -import org.apache.flink.runtime.state.TaskStateHandles; import org.junit.Test; @@ -498,7 +498,7 @@ private CheckpointSequenceValidator(long... checkpointIDs) { } @Override - public void setInitialState(TaskStateHandles taskStateHandles) throws Exception { + public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { throw new UnsupportedOperationException("should never be called"); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 691f0de4a8b16..6a0b8f33927e3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -26,6 +26,8 @@ import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -45,7 +47,6 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.DefaultOperatorStateBackend; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; @@ -56,7 +57,6 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerActions; @@ -189,11 +189,11 @@ private static Task createTask( when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); - ChainedStateHandle operatorState = null; - List keyedStateFromBackend = Collections.emptyList(); - List keyedStateFromStream = Collections.emptyList(); - List> operatorStateBackend = Collections.emptyList(); - List> operatorStateStream = Collections.emptyList(); + StreamStateHandle operatorState = null; + Collection keyedStateFromBackend = Collections.emptyList(); + Collection keyedStateFromStream = Collections.emptyList(); + Collection operatorStateBackend = Collections.emptyList(); + Collection operatorStateStream = Collections.emptyList(); Map operatorStateMetadata = new HashMap<>(1); OperatorStateHandle.StateMetaInfo metaInfo = @@ -210,10 +210,10 @@ private static Task createTask( switch (mode) { case OPERATOR_MANAGED: - operatorStateBackend = Collections.singletonList(operatorStateHandles); + operatorStateBackend = operatorStateHandles; break; case OPERATOR_RAW: - operatorStateStream = Collections.singletonList(operatorStateHandles); + operatorStateStream = operatorStateHandles; break; case KEYED_MANAGED: keyedStateFromBackend = keyedStateHandles; @@ -222,19 +222,22 @@ private static Task createTask( keyedStateFromStream = keyedStateHandles; break; case LEGACY: - operatorState = new ChainedStateHandle<>(Collections.singletonList(state)); + operatorState = state; break; default: throw new IllegalArgumentException(); } - TaskStateHandles taskStateHandles = new TaskStateHandles( + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( operatorState, operatorStateBackend, operatorStateStream, keyedStateFromBackend, keyedStateFromStream); + JobVertexID jobVertexID = new JobVertexID(); + TaskStateSnapshot stateSnapshot = new TaskStateSnapshot(); + stateSnapshot.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), operatorSubtaskState); JobInformation jobInformation = new JobInformation( new JobID(), "test job name", @@ -244,7 +247,7 @@ private static Task createTask( Collections.emptyList()); TaskInformation taskInformation = new TaskInformation( - new JobVertexID(), + jobVertexID, "test task name", 1, 1, @@ -261,7 +264,7 @@ private static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - taskStateHandles, + stateSnapshot, mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index 84a7ef7af126e..8309afb82348f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -41,7 +41,6 @@ import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; @@ -54,7 +53,6 @@ import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatus; -import org.apache.flink.streaming.util.OperatorIDMappedStateToChainConverter; import org.apache.flink.streaming.util.TestHarnessUtil; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; @@ -598,10 +596,7 @@ public void testSnapshottingAndRestoring() throws Exception { TaskStateSnapshot stateHandles = env.getCheckpointStateHandles(); Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size()); - TaskStateHandles taskStateHandles = - OperatorIDMappedStateToChainConverter.convert(stateHandles, restoredTaskStreamConfig, numberChainedTasks); - - restoredTask.setInitialState(taskStateHandles); + restoredTask.setInitialState(stateHandles); TestingStreamOperator.numberRestoreCalls = 0; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index cab91d8f879b5..09e9a1b26144c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -67,7 +67,6 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskExecutionState; @@ -130,6 +129,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyCollectionOf; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; @@ -533,10 +533,10 @@ public Object answer(InvocationOnMock invocation) throws Throwable { OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); // check that the subtask state contains the expected state handles - assertEquals(managedKeyedStateHandle, subtaskState.getManagedKeyedState()); - assertEquals(rawKeyedStateHandle, subtaskState.getRawKeyedState()); - assertEquals(managedOperatorStateHandle, subtaskState.getManagedOperatorState()); - assertEquals(rawOperatorStateHandle, subtaskState.getRawOperatorState()); + assertEquals(Collections.singletonList(managedKeyedStateHandle), subtaskState.getManagedKeyedState()); + assertEquals(Collections.singletonList(rawKeyedStateHandle), subtaskState.getRawKeyedState()); + assertEquals(Collections.singletonList(managedOperatorStateHandle), subtaskState.getManagedOperatorState()); + assertEquals(Collections.singletonList(rawOperatorStateHandle), subtaskState.getRawOperatorState()); // check that the state handles have not been discarded verify(managedKeyedStateHandle, never()).discardState(); @@ -578,8 +578,15 @@ public void testAsyncCheckpointingConcurrentCloseBeforeAcknowledge() throws Exce Environment mockEnvironment = mock(Environment.class); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); - whenNew(OperatorSubtaskState.class).withAnyArguments().thenAnswer(new Answer() { - @Override + whenNew(OperatorSubtaskState.class). + withArguments( + any(StreamStateHandle.class), + anyCollectionOf(OperatorStateHandle.class), + anyCollectionOf(OperatorStateHandle.class), + anyCollectionOf(KeyedStateHandle.class), + anyCollectionOf(KeyedStateHandle.class)). + thenAnswer(new Answer() { + @Override public OperatorSubtaskState answer(InvocationOnMock invocation) throws Throwable { createSubtask.trigger(); completeSubtask.await(); @@ -829,7 +836,7 @@ public static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateHandles(), + new TaskStateSnapshot(), mock(MemoryManager.class), mock(IOManager.class), network, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java index b1a7d69d880dd..15802353ab5c5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java @@ -32,6 +32,7 @@ import org.apache.flink.migration.util.MigrationInstantiationUtil; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; @@ -338,7 +339,7 @@ public void initializeStateFromLegacyCheckpoint(String checkpointFilename) throw } /** - * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}. + * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}. * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} * if it was not called before. * @@ -395,13 +396,12 @@ public void initializeState(OperatorStateHandles operatorStateHandles) throws Ex rawOperatorState, numSubtasks).get(subtaskIndex); - OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles( - 0, - operatorStateHandles.getLegacyOperatorState(), - localManagedKeyGroupState, - localRawKeyGroupState, - localManagedOperatorState, - localRawOperatorState); + OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState( + operatorStateHandles.getLegacyOperatorState(), + nullToEmptyCollection(localManagedOperatorState), + nullToEmptyCollection(localRawOperatorState), + nullToEmptyCollection(localManagedKeyGroupState), + nullToEmptyCollection(localRawKeyGroupState)); operator.initializeState(massagedOperatorStateHandles); } else { @@ -410,6 +410,10 @@ public void initializeState(OperatorStateHandles operatorStateHandles) throws Ex initializeCalled = true; } + private static Collection nullToEmptyCollection(Collection collection) { + return collection != null ? collection : Collections.emptyList(); + } + /** * Takes the different {@link OperatorStateHandles} created by calling {@link #snapshot(long, long)} * on different instances of {@link AbstractStreamOperatorTestHarness} (each one representing one subtask) diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java deleted file mode 100644 index 0bb3ddf08b0a9..0000000000000 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OperatorIDMappedStateToChainConverter.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.util; - -import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; -import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; -import org.apache.flink.streaming.api.graph.StreamConfig; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * Utility to convert state between operator id mapped and chain mapped. - */ -public class OperatorIDMappedStateToChainConverter { - - public static TaskStateHandles convert( - TaskStateSnapshot subtaskStates, - StreamConfig streamConfig, - int chainLength) { - - List operatorIDsInChainOrder = new ArrayList<>(chainLength); - operatorIDsInChainOrder.add(streamConfig.getOperatorID()); - Map chainedTaskConfigs = - streamConfig.getTransitiveChainedTaskConfigs(streamConfig.getClass().getClassLoader()); - for (int i = 1; i < chainLength; ++i) { - operatorIDsInChainOrder.add(chainedTaskConfigs.get(i).getOperatorID()); - } - return convert(subtaskStates, operatorIDsInChainOrder); - } - - public static TaskStateHandles convert(TaskStateSnapshot subtaskStates, List operatorIDsInChainOrder) { - final int chainLength = operatorIDsInChainOrder.size(); - - List legacyStateChain = new ArrayList<>(chainLength); - List> managedOpState = new ArrayList<>(chainLength); - List> rawOpState = new ArrayList<>(chainLength); - - for (int i = 1; i < chainLength; ++i) { - OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateByOperatorID(operatorIDsInChainOrder.get(i)); - legacyStateChain.add(subtaskState.getLegacyOperatorState()); - managedOpState.add(singletonListOrNull(subtaskState.getManagedOperatorState())); - rawOpState.add(singletonListOrNull(subtaskState.getRawOperatorState())); - } - - OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateByOperatorID(operatorIDsInChainOrder.get(0)); - legacyStateChain.add(subtaskState.getLegacyOperatorState()); - managedOpState.add(singletonListOrNull(subtaskState.getManagedOperatorState())); - rawOpState.add(singletonListOrNull(subtaskState.getRawOperatorState())); - - ChainedStateHandle legacyChainedStateHandle = new ChainedStateHandle<>(legacyStateChain); - - TaskStateHandles taskStateHandles = new TaskStateHandles( - legacyChainedStateHandle, - managedOpState, - rawOpState, - singletonListOrNull(subtaskState.getManagedKeyedState()), - singletonListOrNull(subtaskState.getRawKeyedState()) - ); - - return taskStateHandles; - } - - private static List singletonListOrNull(T item) { - return item != null ? Collections.singletonList(item) : null; - } -} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java index a3d45dd48f9ad..a2729feedb313 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java @@ -352,7 +352,7 @@ protected void run() { errMsg = "Initial operator state mismatch."; assertEquals(errMsg, subtaskState.getLegacyOperatorState(), - tdd.getTaskStateHandles().getLegacyOperatorState().get(chainIndexAndJobVertex.f0)); + tdd.getTaskStateHandles().getSubtaskStateByOperatorID(operatorState.getOperatorID()).getLegacyOperatorState()); } } From bce928fcfec73ff7584840ae7eb6b31fb727604f Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Tue, 25 Jul 2017 12:14:03 +0200 Subject: [PATCH 3/5] review comments zentol --- .../contrib/streaming/state/RocksDBAsyncSnapshotTest.java | 6 +++--- .../flink/runtime/checkpoint/OperatorSubtaskState.java | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java index 3d56b5e26ac16..c752e53ccd17c 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java @@ -168,16 +168,16 @@ public void acknowledgeCheckpoint( throw new RuntimeException(e); } - boolean hasKeyedManagedKeyedState = false; + boolean hasManagedKeyedState = false; for (Map.Entry entry : checkpointStateHandles.getSubtaskStateMappings()) { OperatorSubtaskState state = entry.getValue(); if (state != null) { - hasKeyedManagedKeyedState |= state.getManagedKeyedState() != null; + hasManagedKeyedState |= state.getManagedKeyedState() != null; } } // should be one k/v state - assertTrue(hasKeyedManagedKeyedState); + assertTrue(hasManagedKeyedState); // we now know that the checkpoint went through ensureCheckpointLatch.trigger(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index d4c79ebe4dace..05d34984c709c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -43,15 +43,18 @@ * This class encapsulates the state for one parallel instance of an operator. The complete state of a (logical) * operator (e.g. a flatmap operator) consists of the union of all {@link OperatorSubtaskState}s from all * parallel tasks that physically execute parallelized, physical instances of the operator. + * *

The full state of the logical operator is represented by {@link OperatorState} which consists of * {@link OperatorSubtaskState}s. + * *

Typically, we expect all collections in this class to be of size 0 or 1, because there up to one state handle * produced per state type (e.g. managed-keyed, raw-operator, ...). In particular, this holds when taking a snapshot. * The purpose of having the state handles in collections is that this class is also reused in restoring state. * Under normal circumstances, the expected size of each collection is still 0 or 1, except for scale-down. In * scale-down, one operator subtask can become responsible for the state of multiple previous subtasks. The collections * can then store all the state handles that are relevant to build up the new subtask state. - *

There is no collection for legacy state because it is nor rescalable. + * + *

There is no collection for legacy state because it is not rescalable. */ public class OperatorSubtaskState implements CompositeStateHandle { @@ -269,7 +272,6 @@ public long getStateSize() { // -------------------------------------------------------------------------------------------- - @Override public boolean equals(Object o) { if (this == o) { From 363f0ee18e06affe95c30095ed229ca8dfd47801 Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Wed, 26 Jul 2017 13:31:30 +0200 Subject: [PATCH 4/5] review comments zentol part 2 --- .../flink/runtime/state/TaskStateHandles.java | 177 ------------------ 1 file changed, 177 deletions(-) delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java deleted file mode 100644 index 9a00e68854435..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import org.apache.flink.runtime.checkpoint.SubtaskState; -import org.apache.flink.util.Preconditions; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -/** - * This class encapsulates all state handles for a task. - */ -public class TaskStateHandles implements Serializable { - - public static final TaskStateHandles EMPTY = new TaskStateHandles(); - - private static final long serialVersionUID = 267686583583579359L; - - /** - * State handle with the (non-partitionable) legacy operator state - * - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - private final ChainedStateHandle legacyOperatorState; - - /** Collection of handles which represent the managed keyed state of the head operator */ - private final Collection managedKeyedState; - - /** Collection of handles which represent the raw/streamed keyed state of the head operator */ - private final Collection rawKeyedState; - - /** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */ - private final List> managedOperatorState; - - /** Outer list represents the operator chain, each collection holds handles for raw/streamed state of a single operator */ - private final List> rawOperatorState; - - public TaskStateHandles() { - this(null, - Collections.>emptyList(), - Collections.>emptyList(), - Collections.emptyList(), - Collections.emptyList()); - } - - public TaskStateHandles(SubtaskState checkpointStateHandles) { - this(checkpointStateHandles.getLegacyOperatorState(), - transform(checkpointStateHandles.getManagedOperatorState()), - transform(checkpointStateHandles.getRawOperatorState()), - transform(checkpointStateHandles.getManagedKeyedState()), - transform(checkpointStateHandles.getRawKeyedState())); - } - - public TaskStateHandles( - ChainedStateHandle legacyOperatorState, - List> managedOperatorState, - List> rawOperatorState, - Collection managedKeyedState, - Collection rawKeyedState) { - - this.legacyOperatorState = legacyOperatorState; - this.managedKeyedState = Preconditions.checkNotNull(managedKeyedState); - this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState); - this.managedOperatorState = Preconditions.checkNotNull(managedOperatorState); - this.rawOperatorState = Preconditions.checkNotNull(rawOperatorState); - } - - /** - * @deprecated Non-repartitionable operator state that has been deprecated. - * Can be removed when we remove the APIs for non-repartitionable operator state. - */ - @Deprecated - public ChainedStateHandle getLegacyOperatorState() { - return legacyOperatorState; - } - - public Collection getManagedKeyedState() { - return managedKeyedState; - } - - public Collection getRawKeyedState() { - return rawKeyedState; - } - - public List> getRawOperatorState() { - return rawOperatorState; - } - - public List> getManagedOperatorState() { - return managedOperatorState; - } - - private static List> transform(ChainedStateHandle in) { - if (null == in) { - return Collections.emptyList(); - } - List> out = new ArrayList<>(in.getLength()); - for (int i = 0; i < in.getLength(); ++i) { - OperatorStateHandle osh = in.get(i); - out.add(osh != null ? Collections.singletonList(osh) : null); - } - return out; - } - - private static List transform(T in) { - return in == null ? Collections.emptyList() : Collections.singletonList(in); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - TaskStateHandles that = (TaskStateHandles) o; - - if (legacyOperatorState != null ? - !legacyOperatorState.equals(that.legacyOperatorState) - : that.legacyOperatorState != null) { - return false; - } - if (managedKeyedState != null ? - !managedKeyedState.equals(that.managedKeyedState) - : that.managedKeyedState != null) { - return false; - } - if (rawKeyedState != null ? - !rawKeyedState.equals(that.rawKeyedState) - : that.rawKeyedState != null) { - return false; - } - - if (rawOperatorState != null ? - !rawOperatorState.equals(that.rawOperatorState) - : that.rawOperatorState != null) { - return false; - } - return managedOperatorState != null ? - managedOperatorState.equals(that.managedOperatorState) - : that.managedOperatorState == null; - } - - @Override - public int hashCode() { - int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; - result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); - result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0); - result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); - result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); - return result; - } -} From 98e657ea02c17d972391dd2360c287a00f27231e Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Wed, 26 Jul 2017 13:31:47 +0200 Subject: [PATCH 5/5] review comments zentol part 2 --- .../checkpoint/CheckpointCoordinator.java | 5 ++--- .../checkpoint/OperatorSubtaskState.java | 17 ++++++----------- .../RoundRobinOperatorStateRepartitioner.java | 2 +- .../runtime/checkpoint/TaskStateSnapshot.java | 6 ++++-- .../checkpoint/CheckpointCoordinatorTest.java | 13 ------------- .../CheckpointStateRestoreTest.java | 2 -- .../runtime/tasks/OperatorStateHandles.java | 19 ------------------- .../streaming/runtime/tasks/StreamTask.java | 2 +- 8 files changed, 14 insertions(+), 52 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 2c9c902d7cdcd..667525ca6838f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -42,7 +42,6 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.state.SharedStateRegistry; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.util.Preconditions; import org.apache.flink.util.StringUtils; @@ -1018,7 +1017,7 @@ int getNumScheduledTasks() { * Restores the latest checkpointed state. * * @param tasks Map of job vertices to restore. State for these vertices is - * restored via {@link Execution#setInitialState(TaskStateHandles)}. + * restored via {@link Execution#setInitialState(TaskStateSnapshot)}. * @param errorIfNoCheckpoint Fail if no completed checkpoint is available to * restore from. * @param allowNonRestoredState Allow checkpoint state that cannot be mapped @@ -1104,7 +1103,7 @@ public boolean restoreLatestCheckpointedState( * mapped to any job vertex in tasks. * @param tasks Map of job vertices to restore. State for these * vertices is restored via - * {@link Execution#setInitialState(TaskStateHandles)}. + * {@link Execution#setInitialState(TaskStateSnapshot)}. * @param userClassLoader The class loader to resolve serialized classes in * legacy savepoint versions. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index 05d34984c709c..296b5ab29dd32 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -47,7 +47,7 @@ *

The full state of the logical operator is represented by {@link OperatorState} which consists of * {@link OperatorSubtaskState}s. * - *

Typically, we expect all collections in this class to be of size 0 or 1, because there up to one state handle + *

Typically, we expect all collections in this class to be of size 0 or 1, because there is up to one state handle * produced per state type (e.g. managed-keyed, raw-operator, ...). In particular, this holds when taking a snapshot. * The purpose of having the state handles in collections is that this class is also reused in restoring state. * Under normal circumstances, the expected size of each collection is still 0 or 1, except for scale-down. In @@ -106,16 +106,11 @@ public class OperatorSubtaskState implements CompositeStateHandle { @VisibleForTesting public OperatorSubtaskState(StreamStateHandle legacyOperatorState) { - this.legacyOperatorState = legacyOperatorState; - this.managedOperatorState = Collections.emptyList(); - this.rawOperatorState = Collections.emptyList(); - this.managedKeyedState = Collections.emptyList(); - this.rawKeyedState = Collections.emptyList(); - try { - this.stateSize = getSizeNullSafe(legacyOperatorState); - } catch (Exception e) { - throw new RuntimeException("Failed to get state size.", e); - } + this(legacyOperatorState, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 5bf9115756f91..4513ef80b32b1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -89,7 +89,7 @@ private GroupByStateNameResults groupByStateName( for (OperatorStateHandle psh : previousParallelSubtaskStates) { - if(psh == null) { + if (psh == null) { continue; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java index d464423134459..c416f3f641c10 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java @@ -34,10 +34,12 @@ * register their state under their operator id. Each operator instance is a physical execution responsible for * processing a partition of the data that goes through a logical operator. This partitioning happens to parallelize * execution of logical operators, e.g. distributing a map function. - *

One instance of this class contains the information that one task will send to acknowledge a checkpoint request by t - * he checkpoint coordinator. Tasks run operator instances in parallel, so the union of all + * + *

One instance of this class contains the information that one task will send to acknowledge a checkpoint request by + * the checkpoint coordinator. Tasks run operator instances in parallel, so the union of all * {@link TaskStateSnapshot} that are collected by the checkpoint coordinator from all tasks represent the whole * state of a job at the time of the checkpoint. + * *

This class should be called TaskState once the old class with this name that we keep for backwards * compatibility goes away. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 6be660781bafd..7b7fcc2338aa6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -3672,17 +3672,4 @@ public void testSavepointsAreNotAddedToCompletedCheckpointStore() throws Excepti "The latest completed (proper) checkpoint should have been added to the completed checkpoint store.", completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast()); } - -// private static final class SpyInjectingOperatorState extends OperatorState { -// -// private static final long serialVersionUID = -4004437428483663815L; -// -// public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { -// super(taskID, parallelism, maxParallelism); -// } -// -// public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { -// super.putState(subtaskIndex, (subtaskState != null) ? spy(subtaskState) : null); -// } -// } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 53150350bb6b1..6ce071b2269bd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -129,8 +129,6 @@ public void testSetState() { Collections.singletonList(serializedKeyGroupStates), Collections.emptyList())); - //SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates)); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java index 1a79f5429c206..4914075d0136d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java @@ -20,13 +20,10 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.CollectionUtil; -import org.apache.flink.util.Preconditions; import java.util.Collection; import java.util.List; @@ -63,22 +60,6 @@ public OperatorStateHandles( this.rawOperatorState = rawOperatorState; } - public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) { - Preconditions.checkNotNull(taskStateHandles); - - this.operatorChainIndex = operatorChainIndex; - - ChainedStateHandle legacyState = taskStateHandles.getLegacyOperatorState(); - this.legacyOperatorState = ChainedStateHandle.isNullOrEmpty(legacyState) ? - null : legacyState.get(operatorChainIndex); - - this.rawKeyedState = taskStateHandles.getRawKeyedState(); - this.managedKeyedState = taskStateHandles.getManagedKeyedState(); - - this.managedOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex); - this.rawOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex); - } - public StreamStateHandle getLegacyOperatorState() { return legacyOperatorState; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 70b1d78ad85a7..cb8639b0dbe60 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -892,7 +892,7 @@ public void run() { if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.COMPLETED)) { - // we signal a stateless task by reporting null, so that there are no attempts to assign empty state + // we signal stateless tasks by reporting null, so that there are no attempts to assign empty state // to stateless tasks on restore. This enables simple job modifications that only concern // stateless without the need to assign them uids to match their (always empty) states. owner.getEnvironment().acknowledgeCheckpoint(