diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index f67daab8f5ad2..e722b902584b0 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -768,7 +768,10 @@ private static final class RocksDBIncrementalSnapshotOperation { /** The state meta data. */ private final List> stateMetaInfoSnapshots = new ArrayList<>(); + /** Local filesystem for the RocksDB backup. */ private FileSystem backupFileSystem; + + /** Local path for the RocksDB backup. */ private Path backupPath; // Registry for all opened i/o streams @@ -831,11 +834,12 @@ private StreamStateHandle materializeStateData(Path filePath) throws Exception { return result; } finally { - if (inputStream != null && closeableRegistry.unregisterCloseable(inputStream)) { + + if (closeableRegistry.unregisterCloseable(inputStream)) { inputStream.close(); } - if (outputStream != null && closeableRegistry.unregisterCloseable(outputStream)) { + if (closeableRegistry.unregisterCloseable(outputStream)) { outputStream.close(); } } @@ -1041,7 +1045,13 @@ public void restore(Collection restoreState) throws Exception @Override public void notifyCheckpointComplete(long completedCheckpointId) { + + if (!enableIncrementalCheckpointing) { + return; + } + synchronized (materializedSstFiles) { + if (completedCheckpointId < lastCompletedCheckpointId) { return; } @@ -1153,8 +1163,7 @@ private void restoreKeyGroupsInStateHandle() restoreKVStateMetaData(); restoreKVStateData(); } finally { - if (currentStateHandleInStream != null - && rocksDBKeyedStateBackend.cancelStreamRegistry.unregisterCloseable(currentStateHandleInStream)) { + if (rocksDBKeyedStateBackend.cancelStreamRegistry.unregisterCloseable(currentStateHandleInStream)) { IOUtils.closeQuietly(currentStateHandleInStream); } } @@ -1318,7 +1327,7 @@ private RocksDBIncrementalRestoreOperation(RocksDBKeyedStateBackend stateBack return serializationProxy.getStateMetaInfoSnapshots(); } finally { - if (inputStream != null && stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) { + if (stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) { inputStream.close(); } } @@ -1350,11 +1359,11 @@ private void readStateData( outputStream.write(buffer, 0, numBytes); } } finally { - if (inputStream != null && stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) { + if (stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) { inputStream.close(); } - if (outputStream != null && stateBackend.cancelStreamRegistry.unregisterCloseable(outputStream)) { + if (stateBackend.cancelStreamRegistry.unregisterCloseable(outputStream)) { outputStream.close(); } } diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java index 2ba0494f4b6ee..4aa46b3f5140f 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; @@ -42,10 +43,12 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -132,21 +135,15 @@ public String getKey(String value) throws Exception { final OneShotLatch delayCheckpointLatch = new OneShotLatch(); final OneShotLatch ensureCheckpointLatch = new OneShotLatch(); - StreamMockEnvironment mockEnv = new StreamMockEnvironment( - testHarness.jobConfig, - testHarness.taskConfig, - testHarness.memorySize, - new MockInputSplitProvider(), - testHarness.bufferSize) { + CheckpointResponder checkpointResponderMock = new CheckpointResponder() { @Override public void acknowledgeCheckpoint( - long checkpointId, - CheckpointMetrics checkpointMetrics, - TaskStateSnapshot checkpointStateHandles) { - - super.acknowledgeCheckpoint(checkpointId, checkpointMetrics); - + JobID jobID, + ExecutionAttemptID executionAttemptID, + long checkpointId, + CheckpointMetrics checkpointMetrics, + TaskStateSnapshot subtaskState) { // block on the latch, to verify that triggerCheckpoint returns below, // even though the async checkpoint would not finish try { @@ -156,7 +153,7 @@ public void acknowledgeCheckpoint( } boolean hasManagedKeyedState = false; - for (Map.Entry entry : checkpointStateHandles.getSubtaskStateMappings()) { + for (Map.Entry entry : subtaskState.getSubtaskStateMappings()) { OperatorSubtaskState state = entry.getValue(); if (state != null) { hasManagedKeyedState |= state.getManagedKeyedState() != null; @@ -169,8 +166,30 @@ public void acknowledgeCheckpoint( // we now know that the checkpoint went through ensureCheckpointLatch.trigger(); } + + @Override + public void declineCheckpoint( + JobID jobID, ExecutionAttemptID executionAttemptID, + long checkpointId, Throwable cause) { + + } }; + JobID jobID = new JobID(); + ExecutionAttemptID executionAttemptID = new ExecutionAttemptID(0L, 0L); + TestTaskStateManager taskStateManagerTestMock = new TestTaskStateManager( + jobID, + executionAttemptID, + checkpointResponderMock); + + StreamMockEnvironment mockEnv = new StreamMockEnvironment( + testHarness.jobConfig, + testHarness.taskConfig, + testHarness.memorySize, + new MockInputSplitProvider(), + testHarness.bufferSize, + taskStateManagerTestMock); + testHarness.invoke(mockEnv); // wait for the task to be running @@ -260,12 +279,15 @@ public MemCheckpointStreamFactory.MemoryCheckpointOutputStream createCheckpointS streamConfig.setStreamOperator(new AsyncCheckpointOperator()); streamConfig.setOperatorID(new OperatorID()); + TestTaskStateManager taskStateManagerTestMock = new TestTaskStateManager(); + StreamMockEnvironment mockEnv = new StreamMockEnvironment( testHarness.jobConfig, testHarness.taskConfig, testHarness.memorySize, new MockInputSplitProvider(), - testHarness.bufferSize); + testHarness.bufferSize, + taskStateManagerTestMock); blockerCheckpointStreamFactory.setBlockerLatch(new OneShotLatch()); blockerCheckpointStreamFactory.setWaiterLatch(new OneShotLatch()); diff --git a/flink-contrib/flink-storm/src/test/java/org/apache/flink/storm/wrappers/BoltWrapperTest.java b/flink-contrib/flink-storm/src/test/java/org/apache/flink/storm/wrappers/BoltWrapperTest.java index f518d178b5dfe..5e200bc309a02 100644 --- a/flink-contrib/flink-storm/src/test/java/org/apache/flink/storm/wrappers/BoltWrapperTest.java +++ b/flink-contrib/flink-storm/src/test/java/org/apache/flink/storm/wrappers/BoltWrapperTest.java @@ -23,6 +23,7 @@ import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.UnmodifiableConfiguration; +import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; @@ -375,11 +376,13 @@ public Map getComponentConfiguration() { when(env.getMetricGroup()).thenReturn(new UnregisteredTaskMetricsGroup()); when(env.getTaskManagerInfo()).thenReturn(new TestingTaskManagerRuntimeInfo()); + final CloseableRegistry closeableRegistry = new CloseableRegistry(); StreamTask mockTask = mock(StreamTask.class); when(mockTask.getCheckpointLock()).thenReturn(new Object()); when(mockTask.getConfiguration()).thenReturn(new StreamConfig(new Configuration())); when(mockTask.getEnvironment()).thenReturn(env); when(mockTask.getExecutionConfig()).thenReturn(execConfig); + when(mockTask.getCancelables()).thenReturn(closeableRegistry); return mockTask; } diff --git a/flink-core/src/main/java/org/apache/flink/util/CloseableIterable.java b/flink-core/src/main/java/org/apache/flink/util/CloseableIterable.java new file mode 100644 index 0000000000000..a30cf252db3a6 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/util/CloseableIterable.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.util; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; + +/** + * This interface represents an iterable that is also closeable. + * + * @param type of the iterated objects. + */ +public interface CloseableIterable extends Iterable, Closeable { + + class Empty implements CloseableIterable { + + private Empty() { + } + + @Override + public void close() throws IOException { + + } + + @Override + public Iterator iterator() { + return Collections.emptyIterator(); + } + } + + static CloseableIterable empty() { + return new CloseableIterable.Empty<>(); + } +} diff --git a/flink-core/src/main/java/org/apache/flink/util/Migration.java b/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java similarity index 58% rename from flink-core/src/main/java/org/apache/flink/util/Migration.java rename to flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java index a82488d5b3c21..042b48aad8847 100644 --- a/flink-core/src/main/java/org/apache/flink/util/Migration.java +++ b/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java @@ -18,11 +18,38 @@ package org.apache.flink.util; -import org.apache.flink.annotation.Internal; +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; /** - * Tagging interface for migration related classes. + * This interface represents an iterator that is also closeable. + * + * @param type of the iterated objects. */ -@Internal -public interface Migration { +public interface CloseableIterator extends Iterator, Closeable { + + class Empty implements CloseableIterator { + + private Empty() { + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public T next() { + return null; + } + + @Override + public void close() throws IOException { + } + } + + static CloseableIterator empty() { + return new Empty<>(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 9a4456ef7d7d0..82b839b4b37c2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -1084,7 +1084,7 @@ public boolean restoreLatestCheckpointedState( final Map operatorStates = latest.getOperatorStates(); StateAssignmentOperation stateAssignmentOperation = - new StateAssignmentOperation(tasks, operatorStates, allowNonRestoredState); + new StateAssignmentOperation(latest.getCheckpointID(), tasks, operatorStates, allowNonRestoredState); stateAssignmentOperation.assignStates(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/JobManagerTaskRestore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/JobManagerTaskRestore.java new file mode 100644 index 0000000000000..d5ac3e061d128 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/JobManagerTaskRestore.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import java.io.Serializable; + +/** + * This class encapsulates the data from the job manager to restore a task. + */ +public class JobManagerTaskRestore implements Serializable { + + private static final long serialVersionUID = 1L; + + private final long restoreCheckpointId; + + private final TaskStateSnapshot taskStateSnapshot; + + public JobManagerTaskRestore(long restoreCheckpointId, TaskStateSnapshot taskStateSnapshot) { + this.restoreCheckpointId = restoreCheckpointId; + this.taskStateSnapshot = taskStateSnapshot; + } + + public long getRestoreCheckpointId() { + return restoreCheckpointId; + } + + public TaskStateSnapshot getTaskStateSnapshot() { + return taskStateSnapshot; + } + + @Override + public String toString() { + return "TaskRestore{" + + "restoreCheckpointId=" + restoreCheckpointId + + ", taskStateSnapshot=" + taskStateSnapshot + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index d80311ca74736..e108bad9c230e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -55,13 +55,17 @@ public class StateAssignmentOperation { private final Map tasks; private final Map operatorStates; + + private final long restoreCheckpointId; private final boolean allowNonRestoredState; public StateAssignmentOperation( - Map tasks, - Map operatorStates, - boolean allowNonRestoredState) { + long restoreCheckpointId, + Map tasks, + Map operatorStates, + boolean allowNonRestoredState) { + this.restoreCheckpointId = restoreCheckpointId; this.tasks = Preconditions.checkNotNull(tasks); this.operatorStates = Preconditions.checkNotNull(operatorStates); this.allowNonRestoredState = allowNonRestoredState; @@ -214,7 +218,8 @@ private void assignTaskStateToExecutionJobVertices( } if (!statelessTask) { - currentExecutionAttempt.setInitialState(taskState); + JobManagerTaskRestore taskRestore = new JobManagerTaskRestore(restoreCheckpointId, taskState); + currentExecutionAttempt.setInitialState(taskRestore); } } } @@ -230,7 +235,7 @@ public static OperatorSubtaskState operatorSubtaskStateFrom( !subRawOperatorState.containsKey(instanceID) && !subManagedKeyedState.containsKey(instanceID) && !subRawKeyedState.containsKey(instanceID)) { - + return new OperatorSubtaskState(); } if (!subManagedKeyedState.containsKey(instanceID)) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 0c7e308bfd7bd..78cb451bcf526 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.blob.PermanentBlobService; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; @@ -142,8 +142,8 @@ public Offloaded(PermanentBlobKey serializedValueKey) { /** Slot number to run the sub task in on the target machine. */ private final int targetSlotNumber; - /** State handles for the sub task. */ - private final TaskStateSnapshot taskStateHandles; + /** Information to restore the task. */ + private final JobManagerTaskRestore taskRestore; public TaskDeploymentDescriptor( JobID jobId, @@ -154,7 +154,7 @@ public TaskDeploymentDescriptor( int subtaskIndex, int attemptNumber, int targetSlotNumber, - TaskStateSnapshot taskStateHandles, + JobManagerTaskRestore taskRestore, Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors) { @@ -175,7 +175,7 @@ public TaskDeploymentDescriptor( Preconditions.checkArgument(0 <= targetSlotNumber, "The target slot number must be positive."); this.targetSlotNumber = targetSlotNumber; - this.taskStateHandles = taskStateHandles; + this.taskRestore = taskRestore; this.producedPartitions = Preconditions.checkNotNull(resultPartitionDeploymentDescriptors); this.inputGates = Preconditions.checkNotNull(inputGateDeploymentDescriptors); @@ -263,8 +263,8 @@ public Collection getInputGates() { return inputGates; } - public TaskStateSnapshot getTaskStateHandles() { - return taskStateHandles; + public JobManagerTaskRestore getTaskRestore() { + return taskRestore; } public AllocationID getAllocationId() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 203ee8547cf42..b4487943db570 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -36,6 +36,7 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -144,6 +145,8 @@ public interface Environment { BroadcastVariableManager getBroadcastVariableManager(); + TaskStateManager getTaskStateManager(); + /** * Return the registry for accumulators which are periodically sent to the job manager. * @return the registry diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index 38c382108befe..381ea6f3d5a82 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -25,7 +25,7 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; @@ -145,8 +145,8 @@ public class Execution implements AccessExecution, Archiveable> getPreferredLocations( */ public Collection> getPreferredLocationsBasedOnState() { TaskManagerLocation priorLocation; - if (currentExecution.getTaskStateSnapshot() != null && (priorLocation = getLatestPriorLocation()) != null) { + if (currentExecution.getTaskRestore() != null && (priorLocation = getLatestPriorLocation()) != null) { return Collections.singleton(CompletableFuture.completedFuture(priorLocation)); } else { @@ -745,7 +745,7 @@ void notifyStateTransition(Execution execution, ExecutionState newState, Throwab TaskDeploymentDescriptor createDeploymentDescriptor( ExecutionAttemptID executionId, SimpleSlot targetSlot, - TaskStateSnapshot taskStateHandles, + JobManagerTaskRestore taskRestore, int attemptNumber) throws ExecutionGraphException { // Produced intermediate results @@ -833,7 +833,7 @@ TaskDeploymentDescriptor createDeploymentDescriptor( subTaskIndex, attemptNumber, targetSlot.getRoot().getSlotNumber(), - taskStateHandles, + taskRestore, producedPartitions, consumedPartitions); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java index 00db01ffd2e04..7ecb461eb2916 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java @@ -21,7 +21,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; /** * This interface must be implemented by any invokable that has recoverable state and participates @@ -29,14 +28,6 @@ */ public interface StatefulTask { - /** - * Sets the initial state of the operator, upon recovery. The initial state is typically - * a snapshot of the state from a previous execution. - * - * @param taskStateHandles All state handle for the task. - */ - void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception; - /** * This method is called to trigger a checkpoint, asynchronously by the checkpoint * coordinator. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index fea537b6ebd7e..90ec8b98e9d4b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -138,7 +138,7 @@ private StreamCompressionDecorator determineStreamCompression(ExecutionConfig ex @Override public void dispose() { - IOUtils.closeQuietly(this); + IOUtils.closeQuietly(cancelStreamRegistry); if (kvStateRegistry != null) { kvStateRegistry.unregisterAll(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointListener.java index f0e2b79d718fe..0c99316d932c9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointListener.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointListener.java @@ -18,11 +18,14 @@ package org.apache.flink.runtime.state; +import org.apache.flink.annotation.PublicEvolving; + /** * This interface must be implemented by functions/operations that want to receive * a commit notification once a checkpoint has been completely acknowledged by all * participants. */ +@PublicEvolving public interface CheckpointListener { /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStreamFactory.java index 199a856c8dc9b..53334319d5ea0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStreamFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStreamFactory.java @@ -19,10 +19,11 @@ import org.apache.flink.core.fs.FSDataOutputStream; +import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; -public interface CheckpointStreamFactory { +public interface CheckpointStreamFactory extends Closeable { /** * Creates an new {@link CheckpointStateOutputStream}. When the stream @@ -43,9 +44,9 @@ CheckpointStateOutputStream createCheckpointStateOutputStream( * Closes the stream factory, releasing all internal resources, but does not delete any * persistent checkpoint data. * - * @throws Exception Exceptions can be forwarded and will be logged by the system + * @throws IOException Exceptions can be forwarded and will be logged by the system */ - void close() throws Exception; + void close() throws IOException; /** * A dedicated output stream that produces a {@link StreamStateHandle} when closed. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index 9edf8fcf83d40..2a7df6c7ce937 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -147,7 +147,7 @@ public void close() throws IOException { @Override public void dispose() { - IOUtils.closeQuietly(this); + IOUtils.closeQuietly(closeStreamOnCancelRegistry); registeredStates.clear(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java deleted file mode 100644 index 1960c1c95f431..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import org.apache.flink.core.fs.AbstractMultiFSDataInputStream; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.util.Preconditions; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -/** - * Wrapper class that takes multiple {@link StreamStateHandle} and makes them look like a single one. This is done by - * providing a contiguous view on all the streams of the inner handles through a wrapper stream and by summing up all - * all the meta data. - */ -public class MultiStreamStateHandle implements StreamStateHandle { - - private static final long serialVersionUID = -4588701089489569707L; - private final List stateHandles; - private final long stateSize; - - public MultiStreamStateHandle(List stateHandles) { - this.stateHandles = Preconditions.checkNotNull(stateHandles); - long calculateSize = 0L; - for(StreamStateHandle stateHandle : stateHandles) { - calculateSize += stateHandle.getStateSize(); - } - this.stateSize = calculateSize; - } - - @Override - public FSDataInputStream openInputStream() throws IOException { - return new MultiFSDataInputStream(stateHandles); - } - - @Override - public void discardState() throws Exception { - StateUtil.bestEffortDiscardAllStateObjects(stateHandles); - } - - @Override - public long getStateSize() { - return stateSize; - } - - @Override - public String toString() { - return "MultiStreamStateHandle{" + - "stateHandles=" + stateHandles + - ", stateSize=" + stateSize + - '}'; - } - - static final class MultiFSDataInputStream extends AbstractMultiFSDataInputStream { - - private final TreeMap stateHandleMap; - - public MultiFSDataInputStream(List stateHandles) throws IOException { - this.stateHandleMap = new TreeMap<>(); - this.totalPos = 0L; - long calculateSize = 0L; - for (StreamStateHandle stateHandle : stateHandles) { - stateHandleMap.put(calculateSize, stateHandle); - calculateSize += stateHandle.getStateSize(); - } - this.totalAvailable = calculateSize; - - if (totalAvailable > 0L) { - StreamStateHandle first = stateHandleMap.firstEntry().getValue(); - delegate = first.openInputStream(); - } - } - - @Override - protected FSDataInputStream getSeekedStreamForOffset(long globalStreamOffset) throws IOException { - Map.Entry handleEntry = stateHandleMap.floorEntry(globalStreamOffset); - if (handleEntry != null) { - FSDataInputStream stream = handleEntry.getValue().openInputStream(); - stream.seek(globalStreamOffset - handleEntry.getKey()); - return stream; - } - return null; - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java index 750d2066e19bc..e1e0b0d2dd48a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java @@ -20,63 +20,34 @@ import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.OperatorStateStore; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.core.fs.CloseableRegistry; -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.util.Preconditions; - -import org.apache.commons.io.IOUtils; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.NoSuchElementException; /** * Default implementation of {@link StateInitializationContext}. */ public class StateInitializationContextImpl implements StateInitializationContext { - /** Closable registry to participate in the operator's cancel/close methods */ - private final CloseableRegistry closableRegistry; - /** Signal whether any state to restore was found */ private final boolean restored; private final OperatorStateStore operatorStateStore; - private final Collection operatorStateHandles; private final KeyedStateStore keyedStateStore; - private final Collection keyGroupsStateHandles; - private final Iterable keyedStateIterable; + private final Iterable rawKeyedStateInputs; + private final Iterable rawOperatorStateInputs; public StateInitializationContextImpl( boolean restored, OperatorStateStore operatorStateStore, KeyedStateStore keyedStateStore, - Collection keyedStateHandles, - Collection operatorStateHandles, - CloseableRegistry closableRegistry) { + Iterable rawKeyedStateInputs, + Iterable rawOperatorStateInputs) { this.restored = restored; - this.closableRegistry = Preconditions.checkNotNull(closableRegistry); this.operatorStateStore = operatorStateStore; this.keyedStateStore = keyedStateStore; - this.operatorStateHandles = operatorStateHandles; - this.keyGroupsStateHandles = transform(keyedStateHandles); - - this.keyedStateIterable = keyGroupsStateHandles == null ? - null - : new Iterable() { - @Override - public Iterator iterator() { - return new KeyGroupStreamIterator(getKeyGroupsStateHandles().iterator(), getClosableRegistry()); - } - }; + this.rawOperatorStateInputs = rawOperatorStateInputs; + this.rawKeyedStateInputs = rawKeyedStateInputs; } @Override @@ -84,32 +55,9 @@ public boolean isRestored() { return restored; } - public Collection getOperatorStateHandles() { - return operatorStateHandles; - } - - public Collection getKeyGroupsStateHandles() { - return keyGroupsStateHandles; - } - - public CloseableRegistry getClosableRegistry() { - return closableRegistry; - } - @Override public Iterable getRawOperatorStateInputs() { - if (null != operatorStateHandles) { - return new Iterable() { - @Override - public Iterator iterator() { - return new OperatorStateStreamIterator( - DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, - getOperatorStateHandles().iterator(), getClosableRegistry()); - } - }; - } else { - return Collections.emptyList(); - } + return rawOperatorStateInputs; } @Override @@ -118,11 +66,7 @@ public Iterable getRawKeyedStateInputs() { throw new IllegalStateException("Attempt to access keyed state from non-keyed operator."); } - if (null != keyGroupsStateHandles) { - return keyedStateIterable; - } else { - return Collections.emptyList(); - } + return rawKeyedStateInputs; } @Override @@ -134,199 +78,4 @@ public OperatorStateStore getOperatorStateStore() { public KeyedStateStore getKeyedStateStore() { return keyedStateStore; } - - public void close() { - IOUtils.closeQuietly(closableRegistry); - } - - private static Collection transform(Collection keyedStateHandles) { - - if (keyedStateHandles == null) { - return null; - } - - List keyGroupsStateHandles = new ArrayList<>(); - - for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { - - if (keyedStateHandle instanceof KeyGroupsStateHandle) { - keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); - } else if (keyedStateHandle != null) { - throw new IllegalStateException("Unexpected state handle type, " + - "expected: " + KeyGroupsStateHandle.class + - ", but found: " + keyedStateHandle.getClass() + "."); - } - } - - return keyGroupsStateHandles; - } - - private static class KeyGroupStreamIterator - extends AbstractStateStreamIterator { - - private Iterator> currentOffsetsIterator; - - public KeyGroupStreamIterator( - Iterator stateHandleIterator, CloseableRegistry closableRegistry) { - - super(stateHandleIterator, closableRegistry); - } - - @Override - public boolean hasNext() { - - if (null != currentStateHandle && currentOffsetsIterator.hasNext()) { - - return true; - } - - closeCurrentStream(); - - while (stateHandleIterator.hasNext()) { - currentStateHandle = stateHandleIterator.next(); - if (currentStateHandle.getKeyGroupRange().getNumberOfKeyGroups() > 0) { - currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator(); - - return true; - } - } - - return false; - } - - @Override - public KeyGroupStatePartitionStreamProvider next() { - - if (!hasNext()) { - - throw new NoSuchElementException("Iterator exhausted"); - } - - Tuple2 keyGroupOffset = currentOffsetsIterator.next(); - try { - if (null == currentStream) { - openCurrentStream(); - } - currentStream.seek(keyGroupOffset.f1); - - return new KeyGroupStatePartitionStreamProvider(currentStream, keyGroupOffset.f0); - } catch (IOException ioex) { - - return new KeyGroupStatePartitionStreamProvider(ioex, keyGroupOffset.f0); - } - } - } - - private static class OperatorStateStreamIterator - extends AbstractStateStreamIterator { - - private final String stateName; //TODO since we only support a single named state in raw, this could be dropped - private long[] offsets; - private int offPos; - - public OperatorStateStreamIterator( - String stateName, - Iterator stateHandleIterator, - CloseableRegistry closableRegistry) { - - super(stateHandleIterator, closableRegistry); - this.stateName = Preconditions.checkNotNull(stateName); - } - - @Override - public boolean hasNext() { - - if (null != offsets && offPos < offsets.length) { - - return true; - } - - closeCurrentStream(); - - while (stateHandleIterator.hasNext()) { - currentStateHandle = stateHandleIterator.next(); - OperatorStateHandle.StateMetaInfo metaInfo = - currentStateHandle.getStateNameToPartitionOffsets().get(stateName); - - if (null != metaInfo) { - long[] metaOffsets = metaInfo.getOffsets(); - if (null != metaOffsets && metaOffsets.length > 0) { - this.offsets = metaOffsets; - this.offPos = 0; - - if (closableRegistry.unregisterCloseable(currentStream)) { - IOUtils.closeQuietly(currentStream); - currentStream = null; - } - - return true; - } - } - } - - return false; - } - - @Override - public StatePartitionStreamProvider next() { - - if (!hasNext()) { - - throw new NoSuchElementException("Iterator exhausted"); - } - - long offset = offsets[offPos++]; - - try { - if (null == currentStream) { - openCurrentStream(); - } - currentStream.seek(offset); - - return new StatePartitionStreamProvider(currentStream); - } catch (IOException ioex) { - - return new StatePartitionStreamProvider(ioex); - } - } - } - - abstract static class AbstractStateStreamIterator - implements Iterator { - - protected final Iterator stateHandleIterator; - protected final CloseableRegistry closableRegistry; - - protected H currentStateHandle; - protected FSDataInputStream currentStream; - - public AbstractStateStreamIterator( - Iterator stateHandleIterator, - CloseableRegistry closableRegistry) { - - this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator); - this.closableRegistry = Preconditions.checkNotNull(closableRegistry); - } - - protected void openCurrentStream() throws IOException { - - Preconditions.checkState(currentStream == null); - - FSDataInputStream stream = currentStateHandle.openInputStream(); - closableRegistry.registerCloseable(stream); - currentStream = stream; - } - - protected void closeCurrentStream() { - if (closableRegistry.unregisterCloseable(currentStream)) { - IOUtils.closeQuietly(currentStream); - } - currentStream = null; - } - - @Override - public void remove() { - throw new UnsupportedOperationException("Read only Iterator"); - } - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java index 6a8a08f1b93c1..d24fbf3869ef6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java @@ -122,7 +122,7 @@ public RunnableFuture getOperatorStateStreamFuture() throws private T closeAndUnregisterStreamToObtainStateHandle( NonClosingCheckpointOutputStream stream) throws IOException { - if (null != stream && closableRegistry.unregisterCloseable(stream.getDelegate())) { + if (stream != null && closableRegistry.unregisterCloseable(stream.getDelegate())) { return stream.closeAndGetHandle(); } else { return null; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java index 09d195add1fde..a7673c081394c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java @@ -21,6 +21,7 @@ import org.apache.flink.util.FutureUtil; import org.apache.flink.util.LambdaUtil; +import java.util.Collection; import java.util.concurrent.RunnableFuture; /** @@ -71,4 +72,22 @@ public static void discardStateFuture(RunnableFuture stat } } } + + /** + * Discards the given state collection future by first trying to cancel it. If this is not possible, then + * the state object contained in the future is calculated and afterwards discarded. + * + * @param stateCollectionFuture to be discarded + * @throws Exception if the discard operation failed + */ + public static void discardStateCollectionFuture( + RunnableFuture> stateCollectionFuture) throws Exception { + + if (null != stateCollectionFuture) { + if (!stateCollectionFuture.cancel(true)) { + Collection stateObject = FutureUtil.runIfNotDoneAndGet(stateCollectionFuture); + bestEffortDiscardAllStateObjects(stateObject); + } + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java new file mode 100644 index 0000000000000..326b95ce62a89 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nonnull; + +import java.util.HashMap; +import java.util.Map; + +/** + * This class holds the all {@link TaskLocalStateStore} objects for a task executor (manager). + * + * TODO: this still still work in progress and partially still acts as a placeholder. + */ +public class TaskExecutorLocalStateStoresManager { + + /** + * This map holds all local state stores for tasks running on the task manager / executor that own the instance of + * this. + */ + private final Map> taskStateManagers; + + public TaskExecutorLocalStateStoresManager() { + this.taskStateManagers = new HashMap<>(); + } + + public TaskLocalStateStore localStateStoreForTask( + JobID jobId, + JobVertexID jobVertexID, + int subtaskIndex) { + + Preconditions.checkNotNull(jobId); + final JobVertexSubtaskKey taskKey = new JobVertexSubtaskKey(jobVertexID, subtaskIndex); + + final Map taskStateManagers = + this.taskStateManagers.computeIfAbsent(jobId, k -> new HashMap<>()); + + return taskStateManagers.computeIfAbsent( + taskKey, k -> new TaskLocalStateStore(jobId, jobVertexID, subtaskIndex)); + } + + public void releaseJob(JobID jobID) { + + Map cleanupLocalStores = taskStateManagers.remove(jobID); + + if (cleanupLocalStores != null) { + doRelease(cleanupLocalStores.values()); + } + } + + public void releaseAll() { + + for (Map stateStoreMap : taskStateManagers.values()) { + doRelease(stateStoreMap.values()); + } + + taskStateManagers.clear(); + } + + private void doRelease(Iterable toRelease) { + if (toRelease != null) { + for (TaskLocalStateStore stateStore : toRelease) { + stateStore.dispose(); + } + } + } + + private static final class JobVertexSubtaskKey { + + @Nonnull + final JobVertexID jobVertexID; + final int subtaskIndex; + + public JobVertexSubtaskKey(@Nonnull JobVertexID jobVertexID, int subtaskIndex) { + this.jobVertexID = Preconditions.checkNotNull(jobVertexID); + this.subtaskIndex = subtaskIndex; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + JobVertexSubtaskKey that = (JobVertexSubtaskKey) o; + + if (subtaskIndex != that.subtaskIndex) { + return false; + } + return jobVertexID.equals(that.jobVertexID); + } + + @Override + public int hashCode() { + int result = jobVertexID.hashCode(); + result = 31 * result + subtaskIndex; + return result; + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStore.java new file mode 100644 index 0000000000000..f7436308abe52 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStore.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; + +/** + * This class will service as a task-manager-level local storage for local checkpointed state. The purpose is to provide + * access to a state that is stored locally for a faster recovery compared to the state that is stored remotely in a + * stable store DFS. For now, this storage is only complementary to the stable storage and local state is typically + * lost in case of machine failures. In such cases (and others), client code of this class must fall back to using the + * slower but highly available store. + * + * TODO this is currently a placeholder / mock that still must be implemented! + */ +public class TaskLocalStateStore { + + /** */ + private final JobID jobID; + + /** */ + private final JobVertexID jobVertexID; + + /** */ + private final int subtaskIndex; + + public TaskLocalStateStore( + JobID jobID, + JobVertexID jobVertexID, + int subtaskIndex) { + + this.jobID = jobID; + this.jobVertexID = jobVertexID; + this.subtaskIndex = subtaskIndex; + } + + public void storeSnapshot(/* TODO */) { + throw new UnsupportedOperationException("TODO!"); + } + + public void dispose() { + throw new UnsupportedOperationException("TODO!"); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManager.java new file mode 100644 index 0000000000000..8b41e9ec1f2e0 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManager.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointMetrics; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.jobgraph.OperatorID; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * This interface provides methods to report and retrieve state for a task. + * + *

When a checkpoint or savepoint is triggered on a task, it will create snapshots for all stream operator instances + * it owns. All operator snapshots from the task are then reported via this interface. A typical implementation will + * dispatch and forward the reported state information to interested parties such as the checkpoint coordinator or a + * local state store. + * + *

This interface also offers the complementary method that provides access to previously saved state of operator + * instances in the task for restore purposes. + */ +public interface TaskStateManager extends CheckpointListener { + + /** + * Report the state snapshots for the operator instances running in the owning task. + * + * @param checkpointMetaData meta data from the checkpoint request. + * @param checkpointMetrics task level metrics for the checkpoint. + * @param acknowledgedState the reported states from the owning task. + */ + void reportStateHandles( + @Nonnull CheckpointMetaData checkpointMetaData, + @Nonnull CheckpointMetrics checkpointMetrics, + @Nullable TaskStateSnapshot acknowledgedState); + + /** + * Returns means to restore previously reported state of an operator running in the owning task. + * + * @param operatorID the id of the operator for which we request state. + * @return previous state for the operator. Null if no previous state exists. + */ + OperatorSubtaskState operatorStates(OperatorID operatorID); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java new file mode 100644 index 0000000000000..3aae6e27b91ea --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointMetrics; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.taskmanager.CheckpointResponder; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * This class is the default implementation of {@link TaskStateManager} and collaborates with the job manager + * through {@link CheckpointResponder}) as well as a task-manager-local state store. Like this, client code does + * not have to deal with the differences between remote or local state on recovery because this class handles both + * cases transparently. + * + * Reported state is tagged by clients so that this class can properly forward to the right receiver for the + * checkpointed state. + * + * TODO: all interaction with local state store must still be implemented! It is currently just a placeholder. + */ +public class TaskStateManagerImpl implements TaskStateManager { + + /** The id of the job for which this manager was created, can report, and recover. */ + private final JobID jobId; + + /** The execution attempt id that this manager reports for. */ + private final ExecutionAttemptID executionAttemptID; + + /** The data given by the job manager to restore the job. This is not set for a new job without previous state. */ + private final JobManagerTaskRestore jobManagerTaskRestore; + + /** The local state store to which this manager reports local state snapshots. */ + private final TaskLocalStateStore localStateStore; + + /** The checkpoint responder through which this manager can report to the job manager. */ + private final CheckpointResponder checkpointResponder; + + public TaskStateManagerImpl( + JobID jobId, + ExecutionAttemptID executionAttemptID, + TaskLocalStateStore localStateStore, + JobManagerTaskRestore jobManagerTaskRestore, + CheckpointResponder checkpointResponder) { + + this.jobId = jobId; + this.localStateStore = localStateStore; + this.jobManagerTaskRestore = jobManagerTaskRestore; + this.executionAttemptID = executionAttemptID; + this.checkpointResponder = checkpointResponder; + } + + @Override + public void reportStateHandles( + @Nonnull CheckpointMetaData checkpointMetaData, + @Nonnull CheckpointMetrics checkpointMetrics, + @Nullable TaskStateSnapshot acknowledgedState) { + + checkpointResponder.acknowledgeCheckpoint( + jobId, + executionAttemptID, + checkpointMetaData.getCheckpointId(), + checkpointMetrics, + acknowledgedState); + } + + @Override + public OperatorSubtaskState operatorStates(OperatorID operatorID) { + + if (jobManagerTaskRestore == null) { + return null; + } + + TaskStateSnapshot taskStateSnapshot = jobManagerTaskRestore.getTaskStateSnapshot(); + return taskStateSnapshot.getSubtaskStateByOperatorID(operatorID); + + /* + TODO!!!!!!! + 1) lookup local states for a matching operatorID / checkpointID. + 2) if nothing available: look into job manager provided state. + 3) massage it into a snapshots and return stuff. + */ + } + + /** + * Tracking when local state can be disposed. + */ + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + //TODO activate and prune local state later + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java index a8246518a655c..9be194f901b70 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java @@ -106,7 +106,7 @@ public FsCheckpointStreamFactory( } @Override - public void close() throws Exception {} + public void close() throws IOException {} @Override public FsCheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java index 3920ce88778cb..31c23bf40e3c4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java @@ -45,7 +45,7 @@ public MemCheckpointStreamFactory(int maxStateSize) { } @Override - public void close() throws Exception {} + public void close() throws IOException {} @Override public CheckpointStateOutputStream createCheckpointStateOutputStream( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java index a348948c7077e..882444b9bf0ec 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.blob.BlobCacheService; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.SlotID; @@ -65,6 +66,10 @@ import org.apache.flink.runtime.rpc.RpcEndpoint; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.akka.AkkaRpcServiceUtils; +import org.apache.flink.runtime.state.TaskLocalStateStore; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TaskStateManagerImpl; +import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager; import org.apache.flink.runtime.taskexecutor.exceptions.CheckpointException; import org.apache.flink.runtime.taskexecutor.exceptions.PartitionException; import org.apache.flink.runtime.taskexecutor.exceptions.SlotAllocationException; @@ -131,6 +136,9 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway { /** The memory manager component in the task manager */ private final MemoryManager memoryManager; + /** The state manager for this task, providing state managers per slot. */ + private final TaskExecutorLocalStateStoresManager localStateStoresManager; + /** The network component in the task manager */ private final NetworkEnvironment networkEnvironment; @@ -175,6 +183,7 @@ public TaskExecutor( TaskManagerLocation taskManagerLocation, MemoryManager memoryManager, IOManager ioManager, + TaskExecutorLocalStateStoresManager localStateStoresManager, NetworkEnvironment networkEnvironment, HighAvailabilityServices haServices, HeartbeatServices heartbeatServices, @@ -193,6 +202,7 @@ public TaskExecutor( this.taskManagerConfiguration = checkNotNull(taskManagerConfiguration); this.taskManagerLocation = checkNotNull(taskManagerLocation); this.memoryManager = checkNotNull(memoryManager); + this.localStateStoresManager = checkNotNull(localStateStoresManager); this.ioManager = checkNotNull(ioManager); this.networkEnvironment = checkNotNull(networkEnvironment); this.haServices = checkNotNull(haServices); @@ -380,6 +390,20 @@ public CompletableFuture submitTask( ResultPartitionConsumableNotifier resultPartitionConsumableNotifier = jobManagerConnection.getResultPartitionConsumableNotifier(); PartitionProducerStateChecker partitionStateChecker = jobManagerConnection.getPartitionStateChecker(); + final TaskLocalStateStore localStateStore = localStateStoresManager.localStateStoreForTask( + jobId, + taskInformation.getJobVertexId(), + tdd.getSubtaskIndex()); + + final JobManagerTaskRestore taskRestore = tdd.getTaskRestore(); + + final TaskStateManager taskStateManager = new TaskStateManagerImpl( + jobId, + tdd.getExecutionAttemptId(), + localStateStore, + taskRestore, + checkpointResponder); + Task task = new Task( jobInformation, taskInformation, @@ -390,11 +414,12 @@ public CompletableFuture submitTask( tdd.getProducedPartitions(), tdd.getInputGates(), tdd.getTargetSlotNumber(), - tdd.getTaskStateHandles(), + taskRestore, memoryManager, ioManager, networkEnvironment, broadcastVariableManager, + taskStateManager, taskManagerActions, inputSplitProvider, checkpointResponder, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java index a24daf0f608fa..00f2b931631f6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java @@ -285,6 +285,7 @@ public static TaskExecutor startTaskManager( taskManagerServices.getTaskManagerLocation(), taskManagerServices.getMemoryManager(), taskManagerServices.getIOManager(), + taskManagerServices.getTaskStateManager(), taskManagerServices.getNetworkEnvironment(), highAvailabilityServices, heartbeatServices, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java index 4daff05d05373..a5e5734484450 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java @@ -41,6 +41,7 @@ import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.KvStateServer; import org.apache.flink.runtime.query.QueryableStateUtils; +import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager; import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; import org.apache.flink.runtime.taskexecutor.slot.TimerService; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; @@ -74,6 +75,7 @@ public class TaskManagerServices { private final TaskSlotTable taskSlotTable; private final JobManagerTable jobManagerTable; private final JobLeaderService jobLeaderService; + private final TaskExecutorLocalStateStoresManager taskStateManager; private TaskManagerServices( TaskManagerLocation taskManagerLocation, @@ -84,7 +86,8 @@ private TaskManagerServices( FileCache fileCache, TaskSlotTable taskSlotTable, JobManagerTable jobManagerTable, - JobLeaderService jobLeaderService) { + JobLeaderService jobLeaderService, + TaskExecutorLocalStateStoresManager taskStateManager) { this.taskManagerLocation = Preconditions.checkNotNull(taskManagerLocation); this.memoryManager = Preconditions.checkNotNull(memoryManager); @@ -95,6 +98,7 @@ private TaskManagerServices( this.taskSlotTable = Preconditions.checkNotNull(taskSlotTable); this.jobManagerTable = Preconditions.checkNotNull(jobManagerTable); this.jobLeaderService = Preconditions.checkNotNull(jobLeaderService); + this.taskStateManager = Preconditions.checkNotNull(taskStateManager); } // -------------------------------------------------------------------------------------------- @@ -137,6 +141,10 @@ public JobLeaderService getJobLeaderService() { return jobLeaderService; } + public TaskExecutorLocalStateStoresManager getTaskStateManager() { + return taskStateManager; + } + // -------------------------------------------------------------------------------------------- // Static factory methods for task manager services // -------------------------------------------------------------------------------------------- @@ -189,7 +197,7 @@ public static TaskManagerServices fromConfiguration( final JobManagerTable jobManagerTable = new JobManagerTable(); final JobLeaderService jobLeaderService = new JobLeaderService(taskManagerLocation); - + final TaskExecutorLocalStateStoresManager taskStateManager = new TaskExecutorLocalStateStoresManager(); return new TaskManagerServices( taskManagerLocation, memoryManager, @@ -199,7 +207,8 @@ public static TaskManagerServices fromConfiguration( fileCache, taskSlotTable, jobManagerTable, - jobLeaderService); + jobLeaderService, + taskStateManager); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 92b58868d666f..9220c04ce2b80 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskStateManager; import java.util.Map; import java.util.concurrent.Future; @@ -63,6 +64,7 @@ public class RuntimeEnvironment implements Environment { private final MemoryManager memManager; private final IOManager ioManager; private final BroadcastVariableManager bcVarManager; + private final TaskStateManager taskStateManager; private final InputSplitProvider splitProvider; private final Map> distCacheEntries; @@ -95,6 +97,7 @@ public RuntimeEnvironment( MemoryManager memManager, IOManager ioManager, BroadcastVariableManager bcVarManager, + TaskStateManager taskStateManager, AccumulatorRegistry accumulatorRegistry, TaskKvStateRegistry kvStateRegistry, InputSplitProvider splitProvider, @@ -117,6 +120,7 @@ public RuntimeEnvironment( this.memManager = checkNotNull(memManager); this.ioManager = checkNotNull(ioManager); this.bcVarManager = checkNotNull(bcVarManager); + this.taskStateManager = checkNotNull(taskStateManager); this.accumulatorRegistry = checkNotNull(accumulatorRegistry); this.kvStateRegistry = checkNotNull(kvStateRegistry); this.splitProvider = checkNotNull(splitProvider); @@ -196,6 +200,11 @@ public BroadcastVariableManager getBroadcastVariableManager() { return bcVarManager; } + @Override + public TaskStateManager getTaskStateManager() { + return taskStateManager; + } + @Override public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 2cb356c81f116..bd979d0e2addc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -35,7 +35,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotCheckpointingException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotReadyException; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -69,6 +69,8 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; @@ -121,7 +123,7 @@ * *

Each Task is run by one dedicated thread. */ -public class Task implements Runnable, TaskActions { +public class Task implements Runnable, TaskActions, CheckpointListener { /** The class logger. */ private static final Logger LOG = LoggerFactory.getLogger(Task.class); @@ -182,6 +184,9 @@ public class Task implements Runnable, TaskActions { /** The BroadcastVariableManager to be used by this task */ private final BroadcastVariableManager broadcastVariableManager; + /** The manager for state of operators running in this task/slot */ + private final TaskStateManager taskStateManager; + /** Serialized version of the job specific execution configuration (see {@link ExecutionConfig}). */ private final SerializedValue serializedExecutionConfig; @@ -254,10 +259,10 @@ public class Task implements Runnable, TaskActions { private volatile ExecutorService asyncCallDispatcher; /** - * The handles to the states that the task was initialized with. Will be set + * Provides previous state that the task can use for restore. Will be set * to null after the initialization, to be memory friendly. */ - private volatile TaskStateSnapshot taskStateHandles; + private volatile JobManagerTaskRestore taskRestore; /** Initialized from the Flink configuration. May also be set at the ExecutionConfig */ private long taskCancellationInterval; @@ -285,11 +290,12 @@ public Task( Collection resultPartitionDeploymentDescriptors, Collection inputGateDeploymentDescriptors, int targetSlotNumber, - TaskStateSnapshot taskStateHandles, + JobManagerTaskRestore taskRestore, MemoryManager memManager, IOManager ioManager, NetworkEnvironment networkEnvironment, BroadcastVariableManager bcVarManager, + TaskStateManager taskStateManager, TaskManagerActions taskManagerActions, InputSplitProvider inputSplitProvider, CheckpointResponder checkpointResponder, @@ -327,7 +333,7 @@ public Task( this.requiredClasspaths = jobInformation.getRequiredClasspathURLs(); this.nameOfInvokableClass = taskInformation.getInvokableClassName(); this.serializedExecutionConfig = jobInformation.getSerializedExecutionConfig(); - this.taskStateHandles = taskStateHandles; + this.taskRestore = taskRestore; Configuration tmConfig = taskManagerConfig.getConfiguration(); this.taskCancellationInterval = tmConfig.getLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL); @@ -336,6 +342,7 @@ public Task( this.memoryManager = Preconditions.checkNotNull(memManager); this.ioManager = Preconditions.checkNotNull(ioManager); this.broadcastVariableManager = Preconditions.checkNotNull(bcVarManager); + this.taskStateManager = Preconditions.checkNotNull(taskStateManager); this.accumulatorRegistry = new AccumulatorRegistry(jobId, executionId); this.inputSplitProvider = Preconditions.checkNotNull(inputSplitProvider); @@ -669,7 +676,7 @@ else if (current == ExecutionState.CANCELING) { Environment env = new RuntimeEnvironment( jobId, vertexId, executionId, executionConfig, taskInfo, jobConfiguration, taskConfiguration, userCodeClassLoader, - memoryManager, ioManager, broadcastVariableManager, + memoryManager, ioManager, broadcastVariableManager, taskStateManager, accumulatorRegistry, kvStateRegistry, inputSplitProvider, distributedCacheEntries, writers, inputGates, checkpointResponder, taskManagerConfig, metrics, this); @@ -677,23 +684,6 @@ else if (current == ExecutionState.CANCELING) { // let the task code create its readers and writers invokable.setEnvironment(env); - // the very last thing before the actual execution starts running is to inject - // the state into the task. the state is non-empty if this is an execution - // of a task that failed but had backuped state from a checkpoint - - if (null != taskStateHandles) { - if (invokable instanceof StatefulTask) { - StatefulTask op = (StatefulTask) invokable; - op.setInitialState(taskStateHandles); - } else { - throw new IllegalStateException("Found operator state for a non-stateful task invokable"); - } - // be memory and GC friendly - since the code stays in invoke() for a potentially long time, - // we clear the reference to the state handle - //noinspection UnusedAssignment - taskStateHandles = null; - } - // ---------------------------------------------------------------- // actual task core work // ---------------------------------------------------------------- @@ -1238,6 +1228,7 @@ public void run() { } } + @Override public void notifyCheckpointComplete(final long checkpointID) { AbstractInvokable invokable = this.invokable; @@ -1253,6 +1244,7 @@ public void notifyCheckpointComplete(final long checkpointID) { public void run() { try { statefulTask.notifyCheckpointComplete(checkpointID); + taskStateManager.notifyCheckpointComplete(checkpointID); } catch (Throwable t) { if (getExecutionState() == ExecutionState.RUNNING) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/OperatorSubtaskDescriptionText.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/OperatorSubtaskDescriptionText.java new file mode 100644 index 0000000000000..510caa1681aae --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/OperatorSubtaskDescriptionText.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.util; + +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.util.Preconditions; + +/** + * This class generates a string that can be used to identify an operator subtask. + */ +public class OperatorSubtaskDescriptionText { + + /** Cached description result. */ + private final String description; + + public OperatorSubtaskDescriptionText(OperatorID operatorId, Class operatorClass, int subtaskIndex, int numberOfTasks) { + + Preconditions.checkArgument(numberOfTasks > 0); + Preconditions.checkArgument(subtaskIndex >= 0); + Preconditions.checkArgument(subtaskIndex < numberOfTasks); + + this.description = operatorClass.getSimpleName() + + "_" + operatorId + + "_(" + subtaskIndex + "/" + numberOfTasks + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + OperatorSubtaskDescriptionText that = (OperatorSubtaskDescriptionText) o; + + return description.equals(that.description); + } + + @Override + public int hashCode() { + return description.hashCode(); + } + + @Override + public String toString() { + return description; + } +} diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index f948df4c08ea6..be1b0c6810fff 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -67,6 +67,7 @@ import org.apache.flink.runtime.metrics.util.MetricUtils import org.apache.flink.runtime.metrics.{MetricRegistryConfiguration, MetricRegistryImpl, MetricRegistry => FlinkMetricRegistry} import org.apache.flink.runtime.process.ProcessReaper import org.apache.flink.runtime.security.{SecurityConfiguration, SecurityUtils} +import org.apache.flink.runtime.state.{TaskExecutorLocalStateStoresManager, TaskStateManagerImpl} import org.apache.flink.runtime.taskexecutor.{TaskExecutor, TaskManagerConfiguration, TaskManagerServices, TaskManagerServicesConfiguration} import org.apache.flink.runtime.util._ import org.apache.flink.runtime.{FlinkActor, LeaderSessionMessageFilter, LogMessages} @@ -1176,6 +1177,21 @@ class TaskManager( config.getTimeout().getSize(), config.getTimeout().getUnit())) + // TODO: wire this so that the manager survives the end of the task + val taskExecutorLocalStateStoresManager = new TaskExecutorLocalStateStoresManager + + val localStateStore = taskExecutorLocalStateStoresManager.localStateStoreForTask( + jobInformation.getJobId, + taskInformation.getJobVertexId, + tdd.getSubtaskIndex) + + val slotStateManager = new TaskStateManagerImpl( + jobInformation.getJobId, + tdd.getExecutionAttemptId, + localStateStore, + tdd.getTaskRestore, + checkpointResponder) + val task = new Task( jobInformation, taskInformation, @@ -1186,11 +1202,12 @@ class TaskManager( tdd.getProducedPartitions, tdd.getInputGates, tdd.getTargetSlotNumber, - tdd.getTaskStateHandles, + tdd.getTaskRestore, memoryManager, ioManager, network, bcVarManager, + slotStateManager, taskManagerConnection, inputSplitProvider, checkpointResponder, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 51d21420c9963..68182321aa55c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -2362,7 +2362,9 @@ private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean s KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); - TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore(); + Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId()); + TaskStateSnapshot taskStateHandles = taskRestore.getTaskStateSnapshot(); final int headOpIndex = operatorIDs.size() - 1; List> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size()); @@ -2562,7 +2564,9 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception final List operatorIds = newJobVertex1.getOperatorIDs(); - TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + JobManagerTaskRestore taskRestore = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore(); + Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId()); + TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot(); OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1)); assertTrue(headOpState.getManagedKeyedState().isEmpty()); @@ -2628,7 +2632,9 @@ public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception final List operatorIds = newJobVertex2.getOperatorIDs(); - TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore(); + Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId()); + TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot(); // operator 3 { @@ -3039,7 +3045,9 @@ public static void verifyStateRestore( for (int i = 0; i < executionJobVertex.getParallelism(); i++) { - TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot(); + JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore(); + Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId()); + TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot(); OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 47daa01bd872f..357735225d669 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -141,11 +141,12 @@ public void testSetState() { // verify that each stateful vertex got the state - BaseMatcher matcher = new BaseMatcher() { + BaseMatcher matcher = new BaseMatcher() { @Override public boolean matches(Object o) { - if (o instanceof TaskStateSnapshot) { - return Objects.equals(o, subtaskStates); + if (o instanceof JobManagerTaskRestore) { + JobManagerTaskRestore taskRestore = (JobManagerTaskRestore) o; + return Objects.equals(taskRestore.getTaskStateSnapshot(), subtaskStates); } return false; } @@ -159,8 +160,8 @@ public void describeTo(Description description) { verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher)); verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher)); - verify(statelessExec1, times(0)).setInitialState(Mockito.any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.any()); + verify(statelessExec1, times(0)).setInitialState(Mockito.any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.any()); } catch (Exception e) { e.printStackTrace(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java index 324702485d8f9..e20d34b1d3e1f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.blob.PermanentBlobKey; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; @@ -74,6 +75,7 @@ public void testSerialization() { vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration)); final int targetSlotNumber = 47; final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot(); + final JobManagerTaskRestore taskRestore = new JobManagerTaskRestore(1L, taskStateHandles); final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor( jobID, @@ -84,7 +86,7 @@ public void testSerialization() { indexInSubtaskGroup, attemptNumber, targetSlotNumber, - taskStateHandles, + taskRestore, producedResults, inputGates); @@ -93,7 +95,7 @@ public void testSerialization() { assertFalse(orig.getSerializedJobInformation() == copy.getSerializedJobInformation()); assertFalse(orig.getSerializedTaskInformation() == copy.getSerializedTaskInformation()); assertFalse(orig.getExecutionAttemptId() == copy.getExecutionAttemptId()); - assertFalse(orig.getTaskStateHandles() == copy.getTaskStateHandles()); + assertFalse(orig.getTaskRestore() == copy.getTaskRestore()); assertFalse(orig.getProducedPartitions() == copy.getProducedPartitions()); assertFalse(orig.getInputGates() == copy.getInputGates()); @@ -104,7 +106,8 @@ public void testSerialization() { assertEquals(orig.getSubtaskIndex(), copy.getSubtaskIndex()); assertEquals(orig.getAttemptNumber(), copy.getAttemptNumber()); assertEquals(orig.getTargetSlotNumber(), copy.getTargetSlotNumber()); - assertEquals(orig.getTaskStateHandles(), copy.getTaskStateHandles()); + assertEquals(orig.getTaskRestore().getRestoreCheckpointId(), copy.getTaskRestore().getRestoreCheckpointId()); + assertEquals(orig.getTaskRestore().getTaskStateSnapshot(), copy.getTaskRestore().getTaskStateSnapshot()); assertEquals(orig.getProducedPartitions(), copy.getProducedPartitions()); assertEquals(orig.getInputGates(), copy.getInputGates()); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java index 15d021aad0901..5b0320fae1cdb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java @@ -24,7 +24,7 @@ import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.blob.VoidBlobWriter; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; @@ -174,7 +174,7 @@ public void testLocalityBasedOnState() throws Exception { // target state ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i]; - target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class)); + target.getCurrentExecutionAttempt().setInitialState(mock(JobManagerTaskRestore.class)); } // validate that the target vertices have the state's location as the location preference diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index 88141d6ca710e..ff8d23baa4594 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -65,6 +65,7 @@ import org.apache.flink.runtime.metrics.NoOpMetricRegistry; import org.apache.flink.runtime.metrics.groups.JobManagerMetricGroup; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -531,6 +532,22 @@ public static class BlockingInvokable extends AbstractInvokable { @Override public void invoke() throws Exception { + + OperatorID operatorID = OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()); + TaskStateManager taskStateManager = getEnvironment().getTaskStateManager(); + OperatorSubtaskState subtaskState = taskStateManager.operatorStates(operatorID); + + if(subtaskState != null) { + int subtaskIndex = getIndexInSubtaskGroup(); + if (subtaskIndex < BlockingStatefulInvokable.recoveredStates.length) { + OperatorStateHandle operatorStateHandle = subtaskState.getManagedOperatorState().iterator().next(); + try (FSDataInputStream in = operatorStateHandle.openInputStream()) { + BlockingStatefulInvokable.recoveredStates[subtaskIndex] = + InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); + } + } + } + while (blocking) { synchronized (lock) { lock.wait(); @@ -553,22 +570,10 @@ public static class BlockingStatefulInvokable extends BlockingInvokable implemen private static volatile CountDownLatch completedCheckpointsLatch = new CountDownLatch(1); - private static volatile long[] recoveredStates = new long[0]; + static volatile long[] recoveredStates = new long[0]; private int completedCheckpoints = 0; - @Override - public void setInitialState( - TaskStateSnapshot taskStateHandles) throws Exception { - int subtaskIndex = getIndexInSubtaskGroup(); - if (subtaskIndex < recoveredStates.length) { - OperatorStateHandle operatorStateHandle = extractSingletonOperatorState(taskStateHandles); - try (FSDataInputStream in = operatorStateHandle.openInputStream()) { - recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); - } - } - } - @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java index 6a02d1f024482..7cb93c29dc3cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java @@ -33,7 +33,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.messages.NotifyResourceStarted; import org.apache.flink.runtime.clusterframework.messages.RegisterResourceManager; import org.apache.flink.runtime.clusterframework.messages.RegisterResourceManagerSuccessful; @@ -1536,10 +1535,6 @@ public void invoke() throws Exception { new CountDownLatch(1).await(); } - @Override - public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { - } - @Override public boolean triggerCheckpoint( CheckpointMetaData checkpointMetaData, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 0125a5e9fc3b7..d6695c312afcc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -38,6 +38,7 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; @@ -53,10 +54,15 @@ public class DummyEnvironment implements Environment { private final ExecutionConfig executionConfig = new ExecutionConfig(); private final TaskInfo taskInfo; private KvStateRegistry kvStateRegistry = new KvStateRegistry(); + private TaskStateManager taskStateManager; private final AccumulatorRegistry accumulatorRegistry = new AccumulatorRegistry(jobId, executionId); public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) { - this.taskInfo = new TaskInfo(taskName, numSubTasks, subTaskIndex, numSubTasks, 0); + this(taskName, numSubTasks, subTaskIndex, numSubTasks); + } + + public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex, int maxParallelism) { + this.taskInfo = new TaskInfo(taskName, maxParallelism, subTaskIndex, numSubTasks, 0); } public void setKvStateRegistry(KvStateRegistry kvStateRegistry) { @@ -142,6 +148,11 @@ public BroadcastVariableManager getBroadcastVariableManager() { return null; } + @Override + public TaskStateManager getTaskStateManager() { + return taskStateManager; + } + @Override public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; @@ -190,4 +201,7 @@ public InputGate[] getAllInputGates() { return null; } + public void setTaskStateManager(TaskStateManager taskStateManager) { + this.taskStateManager = taskStateManager; + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 7514cc4200d74..7d3f914572512 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -46,6 +46,7 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.types.Record; @@ -78,6 +79,8 @@ public class MockEnvironment implements Environment { private final IOManager ioManager; + private final TaskStateManager taskStateManager; + private final InputSplitProvider inputSplitProvider; private final Configuration jobConfiguration; @@ -100,11 +103,29 @@ public class MockEnvironment implements Environment { private final ClassLoader userCodeClassLoader; - public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { - this(taskName, memorySize, inputSplitProvider, bufferSize, new Configuration(), new ExecutionConfig()); + public MockEnvironment( + String taskName, + long memorySize, + MockInputSplitProvider inputSplitProvider, + int bufferSize, + TaskStateManager taskStateManager) { + this( + taskName, + memorySize, + inputSplitProvider, + bufferSize, + new Configuration(), + new ExecutionConfig(), + taskStateManager); } - public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize, Configuration taskConfiguration, ExecutionConfig executionConfig) { + public MockEnvironment( + String taskName, + long memorySize, + MockInputSplitProvider inputSplitProvider, + int bufferSize, Configuration taskConfiguration, + ExecutionConfig executionConfig, + TaskStateManager taskStateManager) { this( taskName, memorySize, @@ -112,6 +133,7 @@ public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider bufferSize, taskConfiguration, executionConfig, + taskStateManager, 1, 1, 0); @@ -124,6 +146,7 @@ public MockEnvironment( int bufferSize, Configuration taskConfiguration, ExecutionConfig executionConfig, + TaskStateManager taskStateManager, int maxParallelism, int parallelism, int subtaskIndex) { @@ -137,7 +160,8 @@ public MockEnvironment( maxParallelism, parallelism, subtaskIndex, - Thread.currentThread().getContextClassLoader()); + Thread.currentThread().getContextClassLoader(), + taskStateManager); } @@ -151,7 +175,8 @@ public MockEnvironment( int maxParallelism, int parallelism, int subtaskIndex, - ClassLoader userCodeClassLoader) { + ClassLoader userCodeClassLoader, + TaskStateManager taskStateManager) { this.taskInfo = new TaskInfo(taskName, maxParallelism, subtaskIndex, parallelism, 0); this.jobConfiguration = new Configuration(); this.taskConfiguration = taskConfiguration; @@ -160,6 +185,7 @@ public MockEnvironment( this.memManager = new MemoryManager(memorySize, 1); this.ioManager = new IOManagerAsync(); + this.executionConfig = executionConfig; this.inputSplitProvider = inputSplitProvider; this.bufferSize = bufferSize; @@ -170,6 +196,7 @@ public MockEnvironment( this.kvStateRegistry = registry.createTaskRegistry(jobID, getJobVertexId()); this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader); + this.taskStateManager = Preconditions.checkNotNull(taskStateManager); } @@ -338,6 +365,11 @@ public BroadcastVariableManager getBroadcastVariableManager() { return this.bcVarManager; } + @Override + public TaskStateManager getTaskStateManager() { + return taskStateManager; + } + @Override public AccumulatorRegistry getAccumulatorRegistry() { return this.accumulatorRegistry; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java index 53d75b35c2e34..ee5b2da9fd111 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java @@ -23,7 +23,6 @@ import org.apache.flink.api.common.io.FileOutputFormat; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; -import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FileSystem.WriteMode; import org.apache.flink.core.fs.Path; @@ -33,10 +32,14 @@ import org.apache.flink.runtime.operators.Driver; import org.apache.flink.runtime.operators.shipping.ShipStrategyType; import org.apache.flink.runtime.operators.util.TaskConfig; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory; import org.apache.flink.types.Record; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.MutableObjectIterator; import org.apache.flink.util.TestLogger; + import org.junit.After; import org.junit.Assert; @@ -53,7 +56,8 @@ public abstract class TaskTestBase extends TestLogger { public void initEnvironment(long memorySize, int bufferSize) { this.memorySize = memorySize; this.inputSplitProvider = new MockInputSplitProvider(); - this.mockEnv = new MockEnvironment("mock task", this.memorySize, this.inputSplitProvider, bufferSize); + TaskStateManager taskStateManager = new TestTaskStateManager(); + this.mockEnv = new MockEnvironment("mock task", this.memorySize, this.inputSplitProvider, bufferSize, taskStateManager); } public IteratorWrappingTestSingleInputGate addInput(MutableObjectIterator input, int groupId) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java deleted file mode 100644 index dd34f030a019a..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Random; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -public class MultiStreamStateHandleTest { - - private static final int TEST_DATA_LENGTH = 123; - private Random random; - private byte[] testData; - private List streamStateHandles; - - @Before - public void setup() { - random = new Random(0x42); - testData = new byte[TEST_DATA_LENGTH]; - for (int i = 0; i < testData.length; ++i) { - testData[i] = (byte) i; - } - - int idx = 0; - streamStateHandles = new ArrayList<>(); - while (idx < testData.length) { - int len = random.nextInt(5); - byte[] sub = Arrays.copyOfRange(testData, idx, idx + len); - streamStateHandles.add(new ByteStreamStateHandle(String.valueOf(idx), sub)); - idx += len; - } - } - - @Test - public void testMetaData() throws IOException { - MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles); - assertEquals(TEST_DATA_LENGTH, multiStreamStateHandle.getStateSize()); - } - - @Test - public void testLinearRead() throws IOException { - MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles); - try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) { - - for (int i = 0; i < TEST_DATA_LENGTH; ++i) { - assertEquals(i, in.getPos()); - assertEquals(testData[i], in.read()); - } - - assertEquals(-1, in.read()); - assertEquals(TEST_DATA_LENGTH, in.getPos()); - assertEquals(-1, in.read()); - assertEquals(TEST_DATA_LENGTH, in.getPos()); - } - } - - @Test - public void testRandomRead() throws IOException { - - MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles); - - try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) { - - for (int i = 0; i < 1000; ++i) { - int pos = random.nextInt(TEST_DATA_LENGTH); - int readLen = random.nextInt(TEST_DATA_LENGTH); - in.seek(pos); - while (--readLen > 0 && pos < TEST_DATA_LENGTH) { - assertEquals(pos, in.getPos()); - assertEquals(testData[pos++], in.read()); - } - } - - in.seek(TEST_DATA_LENGTH); - assertEquals(TEST_DATA_LENGTH, in.getPos()); - assertEquals(-1, in.read()); - - try { - in.seek(TEST_DATA_LENGTH + 1); - fail(); - } catch (Exception ignored) { - - } - } - } - - @Test - public void testEmptyList() throws IOException { - - MultiStreamStateHandle multiStreamStateHandle = - new MultiStreamStateHandle(Collections.emptyList()); - - try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) { - - assertEquals(0, in.getPos()); - in.seek(0); - assertEquals(0, in.getPos()); - assertEquals(-1, in.read()); - - try { - in.seek(1); - fail(); - } catch (Exception ignored) { - - } - } - } -} \ No newline at end of file diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskStateManagerImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskStateManagerImplTest.java new file mode 100644 index 0000000000000..47bbebbcc88e3 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskStateManagerImplTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointMetrics; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.taskmanager.CheckpointResponder; +import org.apache.flink.runtime.taskmanager.TestCheckpointResponder; + +import org.junit.Assert; +import org.junit.Test; + +import static org.mockito.Mockito.mock; + +public class TaskStateManagerImplTest { + + @Test + public void testStateReportingAndRetrieving() { + + JobID jobID = new JobID(42L, 43L); + ExecutionAttemptID executionAttemptID = new ExecutionAttemptID(23L, 24L); + TestCheckpointResponder checkpointResponderMock = new TestCheckpointResponder(); + + TaskStateManager taskStateManager = taskStateManager( + jobID, + executionAttemptID, + checkpointResponderMock, + null); + + //---------------------------------------- test reporting ----------------------------------------- + + CheckpointMetaData checkpointMetaData = new CheckpointMetaData(74L, 11L); + CheckpointMetrics checkpointMetrics = new CheckpointMetrics(); + TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(); + + OperatorID operatorID_1 = new OperatorID(1L, 1L); + OperatorID operatorID_2 = new OperatorID(2L, 2L); + OperatorID operatorID_3 = new OperatorID(3L, 3L); + + Assert.assertNull(taskStateManager.operatorStates(operatorID_1)); + Assert.assertNull(taskStateManager.operatorStates(operatorID_2)); + Assert.assertNull(taskStateManager.operatorStates(operatorID_3)); + + OperatorSubtaskState operatorSubtaskState_1 = new OperatorSubtaskState(); + OperatorSubtaskState operatorSubtaskState_2 = new OperatorSubtaskState(); + + taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_1, operatorSubtaskState_1); + taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_2, operatorSubtaskState_2); + + taskStateManager.reportStateHandles(checkpointMetaData, checkpointMetrics, taskStateSnapshot); + + TestCheckpointResponder.AcknowledgeReport acknowledgeReport = + checkpointResponderMock.getAcknowledgeReports().get(0); + + Assert.assertEquals(checkpointMetaData.getCheckpointId(), acknowledgeReport.getCheckpointId()); + Assert.assertEquals(checkpointMetrics, acknowledgeReport.getCheckpointMetrics()); + Assert.assertEquals(executionAttemptID, acknowledgeReport.getExecutionAttemptID()); + Assert.assertEquals(jobID, acknowledgeReport.getJobID()); + Assert.assertEquals(taskStateSnapshot, acknowledgeReport.getSubtaskState()); + + //---------------------------------------- test retrieving ----------------------------------------- + + JobManagerTaskRestore taskRestore = new JobManagerTaskRestore( + 0L, + acknowledgeReport.getSubtaskState()); + + taskStateManager = taskStateManager( + jobID, + executionAttemptID, + checkpointResponderMock, + taskRestore); + + Assert.assertEquals(operatorSubtaskState_1, taskStateManager.operatorStates(operatorID_1)); + Assert.assertEquals(operatorSubtaskState_2, taskStateManager.operatorStates(operatorID_2)); + Assert.assertNull(taskStateManager.operatorStates(operatorID_3)); + } + + public static TaskStateManager taskStateManager( + JobID jobID, + ExecutionAttemptID executionAttemptID, + CheckpointResponder checkpointResponderMock, + JobManagerTaskRestore jobManagerTaskRestore) { + + // for now just a mock because this is not yet implemented + TaskLocalStateStore taskLocalStateStore = mock(TaskLocalStateStore.class); + + return new TaskStateManagerImpl( + jobID, + executionAttemptID, + taskLocalStateStore, + jobManagerTaskRestore, + checkpointResponderMock); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestTaskStateManager.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestTaskStateManager.java new file mode 100644 index 0000000000000..e973a02429220 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestTaskStateManager.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointMetrics; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.taskmanager.CheckpointResponder; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; + +/** + * Implementation of {@link TaskStateManager} for tests. + */ +public class TestTaskStateManager implements TaskStateManager { + + private long reportedCheckpointId; + + private JobID jobId; + private ExecutionAttemptID executionAttemptID; + + private final Map taskStateSnapshotsByCheckpointId; + private CheckpointResponder checkpointResponder; + private OneShotLatch waitForReportLatch; + + public TestTaskStateManager() { + this(null, null, null); + } + + public TestTaskStateManager( + JobID jobId, + ExecutionAttemptID executionAttemptID) { + + this(jobId, executionAttemptID, null); + } + + public TestTaskStateManager( + JobID jobId, + ExecutionAttemptID executionAttemptID, + CheckpointResponder checkpointResponder) { + this.jobId = jobId; + this.executionAttemptID = executionAttemptID; + this.checkpointResponder = checkpointResponder; + this.taskStateSnapshotsByCheckpointId = new HashMap<>(); + this.reportedCheckpointId = -1L; + } + + @Override + public void reportStateHandles( + @Nonnull CheckpointMetaData checkpointMetaData, + @Nonnull CheckpointMetrics checkpointMetrics, + @Nullable TaskStateSnapshot acknowledgedState) { + + if (taskStateSnapshotsByCheckpointId != null) { + taskStateSnapshotsByCheckpointId.put( + checkpointMetaData.getCheckpointId(), + acknowledgedState); + } + + if (checkpointResponder != null) { + checkpointResponder.acknowledgeCheckpoint( + jobId, + executionAttemptID, + checkpointMetaData.getCheckpointId(), + checkpointMetrics, + acknowledgedState); + } + + this.reportedCheckpointId = checkpointMetaData.getCheckpointId(); + + if (waitForReportLatch != null) { + waitForReportLatch.trigger(); + } + } + + @Override + public OperatorSubtaskState operatorStates(OperatorID operatorID) { + TaskStateSnapshot taskStateSnapshot = getLastTaskStateSnapshot(); + return taskStateSnapshot != null ? taskStateSnapshot.getSubtaskStateByOperatorID(operatorID) : null; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + + } + + public JobID getJobId() { + return jobId; + } + + public void setJobId(JobID jobId) { + this.jobId = jobId; + } + + public ExecutionAttemptID getExecutionAttemptID() { + return executionAttemptID; + } + + public void setExecutionAttemptID(ExecutionAttemptID executionAttemptID) { + this.executionAttemptID = executionAttemptID; + } + + public CheckpointResponder getCheckpointResponder() { + return checkpointResponder; + } + + public void setCheckpointResponder(CheckpointResponder checkpointResponder) { + this.checkpointResponder = checkpointResponder; + } + + public Map getTaskStateSnapshotsByCheckpointId() { + return taskStateSnapshotsByCheckpointId; + } + + public void setTaskStateSnapshotsByCheckpointId(Map taskStateSnapshotsByCheckpointId) { + this.taskStateSnapshotsByCheckpointId.clear(); + this.taskStateSnapshotsByCheckpointId.putAll(taskStateSnapshotsByCheckpointId); + } + + public long getReportedCheckpointId() { + return reportedCheckpointId; + } + + public void setReportedCheckpointId(long reportedCheckpointId) { + this.reportedCheckpointId = reportedCheckpointId; + } + + public TaskStateSnapshot getLastTaskStateSnapshot() { + return taskStateSnapshotsByCheckpointId != null ? + taskStateSnapshotsByCheckpointId.get(reportedCheckpointId) + : null; + } + + public OneShotLatch getWaitForReportLatch() { + return waitForReportLatch; + } + + public void setWaitForReportLatch(OneShotLatch waitForReportLatch) { + this.waitForReportLatch = waitForReportLatch; + } + + public void restoreLatestCheckpointState(Map taskStateSnapshotsByCheckpointId) { + + if (taskStateSnapshotsByCheckpointId == null + || taskStateSnapshotsByCheckpointId.isEmpty()) { + return; + } + + long latestId = -1; + + for (long id : taskStateSnapshotsByCheckpointId.keySet()) { + if (id > latestId) { + latestId = id; + } + } + + setReportedCheckpointId(latestId); + setTaskStateSnapshotsByCheckpointId(taskStateSnapshotsByCheckpointId); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java index 1f1d09d1d88a6..482be3d2327eb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java @@ -50,6 +50,7 @@ import org.apache.flink.runtime.resourcemanager.StandaloneResourceManager; import org.apache.flink.runtime.resourcemanager.slotmanager.SlotManager; import org.apache.flink.runtime.rpc.TestingRpcService; +import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager; import org.apache.flink.runtime.taskexecutor.slot.SlotOffer; import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; import org.apache.flink.runtime.taskexecutor.slot.TimerService; @@ -63,6 +64,7 @@ import java.net.InetAddress; import java.util.Arrays; +import java.util.List; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; @@ -123,7 +125,8 @@ public void testSlotAllocation() throws Exception { final TaskManagerMetricGroup taskManagerMetricGroup = mock(TaskManagerMetricGroup.class); final BroadcastVariableManager broadcastVariableManager = mock(BroadcastVariableManager.class); final FileCache fileCache = mock(FileCache.class); - final TaskSlotTable taskSlotTable = new TaskSlotTable(Arrays.asList(resourceProfile), new TimerService(scheduledExecutorService, 100L)); + final List resourceProfiles = Arrays.asList(resourceProfile); + final TaskSlotTable taskSlotTable = new TaskSlotTable(resourceProfiles, new TimerService(scheduledExecutorService, 100L)); final JobManagerTable jobManagerTable = new JobManagerTable(); final JobLeaderService jobLeaderService = new JobLeaderService(taskManagerLocation); final SlotManager slotManager = new SlotManager( @@ -132,6 +135,8 @@ public void testSlotAllocation() throws Exception { TestingUtils.infiniteTime(), TestingUtils.infiniteTime()); + final TaskExecutorLocalStateStoresManager taskStateManager = new TaskExecutorLocalStateStoresManager(); + ResourceManager resourceManager = new StandaloneResourceManager( rpcService, FlinkResourceManager.RESOURCE_MANAGER_NAME, @@ -150,6 +155,7 @@ public void testSlotAllocation() throws Exception { taskManagerLocation, memoryManager, ioManager, + taskStateManager, networkEnvironment, testingHAServices, heartbeatServices, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java index 776bdf9a5a17e..57096f45812d2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java @@ -70,6 +70,7 @@ import org.apache.flink.runtime.resourcemanager.ResourceManagerId; import org.apache.flink.runtime.rpc.RpcService; import org.apache.flink.runtime.rpc.TestingRpcService; +import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager; import org.apache.flink.runtime.taskexecutor.exceptions.SlotAllocationException; import org.apache.flink.runtime.taskexecutor.slot.SlotOffer; import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; @@ -210,10 +211,11 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation) thro tmConfig, taskManagerLocation, mock(MemoryManager.class), - mock(IOManager.class), + mock(IOManager.class),mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServices, heartbeatServices, + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -315,10 +317,11 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation taskManagerConfiguration, taskManagerLocation, mock(MemoryManager.class), - mock(IOManager.class), + mock(IOManager.class),mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServices, heartbeatServices, + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -432,10 +435,11 @@ public HeartbeatManagerImpl answer(InvocationOnMock invocation taskManagerConfiguration, taskManagerLocation, mock(MemoryManager.class), - mock(IOManager.class), + mock(IOManager.class),mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServices, heartbeatServices, + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -525,9 +529,10 @@ public void testImmediatelyRegistersIfLeaderIsKnown() throws Exception { taskManagerLocation, mock(MemoryManager.class), mock(IOManager.class), - mock(NetworkEnvironment.class), + mock(TaskExecutorLocalStateStoresManager.class),mock(NetworkEnvironment.class), haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -607,9 +612,10 @@ public void testTriggerRegistrationOnLeaderChange() throws Exception { taskManagerLocation, mock(MemoryManager.class), mock(IOManager.class), - mock(NetworkEnvironment.class), + mock(TaskExecutorLocalStateStoresManager.class),mock(NetworkEnvironment.class), haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -748,9 +754,10 @@ public void testTaskSubmission() throws Exception { mock(TaskManagerLocation.class), mock(MemoryManager.class), mock(IOManager.class), - networkEnvironment, + mock(TaskExecutorLocalStateStoresManager.class),networkEnvironment, haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + taskManagerMetricGroup, mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -864,10 +871,11 @@ public void testJobLeaderDetection() throws Exception { taskManagerConfiguration, taskManagerLocation, mock(MemoryManager.class), - mock(IOManager.class), + mock(IOManager.class),mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -983,10 +991,11 @@ public void testSlotAcceptance() throws Exception { taskManagerConfiguration, taskManagerLocation, mock(MemoryManager.class), - mock(IOManager.class), + mock(IOManager.class),mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -1078,9 +1087,10 @@ public void testRejectAllocationRequestsForOutOfSyncSlots() throws Exception { taskManagerLocation, mock(MemoryManager.class), mock(IOManager.class), - mock(NetworkEnvironment.class), + mock(TaskExecutorLocalStateStoresManager.class),mock(NetworkEnvironment.class), haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + mock(TaskManagerMetricGroup.class), mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -1253,9 +1263,10 @@ public void testSubmitTaskBeforeAcceptSlot() throws Exception { taskManagerLocation, mock(MemoryManager.class), mock(IOManager.class), - networkMock, + mock(TaskExecutorLocalStateStoresManager.class),networkMock, haServices, mock(HeartbeatServices.class, RETURNS_MOCKS), + taskManagerMetricGroup, mock(BroadcastVariableManager.class), mock(FileCache.class), @@ -1375,6 +1386,7 @@ public void testFilterOutDuplicateJobMasterRegistrations() throws Exception { taskManagerLocation, mock(MemoryManager.class), mock(IOManager.class), + mock(TaskExecutorLocalStateStoresManager.class), mock(NetworkEnvironment.class), haServicesMock, heartbeatServicesMock, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index 50456067cf955..9a30f2911c9ab 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -29,7 +29,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; @@ -53,6 +52,7 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; @@ -252,11 +252,12 @@ private Task createTask(Class invokableClass) throw Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateSnapshot(), + null, mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, mock(BroadcastVariableManager.class), + new TestTaskStateManager(), mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), @@ -297,9 +298,6 @@ public void invoke() throws Exception { } } - @Override - public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {} - @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) { lastCheckpointId++; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java index 31e9e223b6900..1b66c4662a148 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java @@ -18,13 +18,6 @@ package org.apache.flink.runtime.taskmanager; -import akka.actor.ActorRef; -import akka.actor.ActorSystem; -import akka.actor.Kill; -import akka.actor.Props; -import akka.actor.Status; -import akka.japi.Creator; -import akka.testkit.JavaTestKit; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; @@ -81,6 +74,14 @@ import org.apache.flink.util.NetUtils; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Kill; +import akka.actor.Props; +import akka.actor.Status; +import akka.japi.Creator; +import akka.testkit.JavaTestKit; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; @@ -89,11 +90,6 @@ import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.Option; -import scala.concurrent.Await; -import scala.concurrent.Future; -import scala.concurrent.duration.FiniteDuration; -import scala.util.Failure; import java.io.IOException; import java.net.InetAddress; @@ -111,6 +107,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import scala.Option; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; +import scala.util.Failure; + import static org.apache.flink.runtime.messages.JobManagerMessages.RequestPartitionProducerState; import static org.apache.flink.runtime.messages.JobManagerMessages.ScheduleOrUpdateConsumers; import static org.junit.Assert.assertEquals; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index b08999794e5c3..b42fc33a5b064 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -57,16 +57,18 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; import org.apache.flink.util.WrappingRuntimeException; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import scala.concurrent.duration.FiniteDuration; import javax.annotation.Nonnull; + import java.io.IOException; import java.lang.reflect.Field; import java.util.Collections; @@ -79,6 +81,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import scala.concurrent.duration.FiniteDuration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -1014,6 +1018,7 @@ private Task createTask( mock(IOManager.class), networkEnvironment, mock(BroadcastVariableManager.class), + new TestTaskStateManager(), taskManagerConnection, inputSplitProvider, checkpointResponder, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestCheckpointResponder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestCheckpointResponder.java new file mode 100644 index 0000000000000..d8baee3b52ddf --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestCheckpointResponder.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.checkpoint.CheckpointMetrics; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; + +import java.util.ArrayList; +import java.util.List; + +/** + * Mock for interface {@link CheckpointResponder} for unit testing. + */ +public class TestCheckpointResponder implements CheckpointResponder { + + private final List acknowledgeReports; + private final List declineReports; + + private OneShotLatch acknowledgeLatch; + private OneShotLatch declinedLatch; + + public TestCheckpointResponder() { + this.acknowledgeReports = new ArrayList<>(); + this.declineReports = new ArrayList<>(); + } + + @Override + public void acknowledgeCheckpoint( + JobID jobID, + ExecutionAttemptID executionAttemptID, + long checkpointId, + CheckpointMetrics checkpointMetrics, + TaskStateSnapshot subtaskState) { + + AcknowledgeReport acknowledgeReport = new AcknowledgeReport( + jobID, + executionAttemptID, + checkpointId, + checkpointMetrics, + subtaskState); + + acknowledgeReports.add(acknowledgeReport); + + if (acknowledgeLatch != null) { + acknowledgeLatch.trigger(); + } + } + + @Override + public void declineCheckpoint( + JobID jobID, + ExecutionAttemptID executionAttemptID, + long checkpointId, + Throwable cause) { + + DeclineReport declineReport = new DeclineReport( + jobID, + executionAttemptID, + checkpointId, + cause); + + declineReports.add(declineReport); + + if (declinedLatch != null) { + declinedLatch.trigger(); + } + } + + public static abstract class AbstractReport { + + private final JobID jobID; + private final ExecutionAttemptID executionAttemptID; + private final long checkpointId; + + AbstractReport(JobID jobID, ExecutionAttemptID executionAttemptID, long checkpointId) { + this.jobID = jobID; + this.executionAttemptID = executionAttemptID; + this.checkpointId = checkpointId; + } + + public JobID getJobID() { + return jobID; + } + + public ExecutionAttemptID getExecutionAttemptID() { + return executionAttemptID; + } + + public long getCheckpointId() { + return checkpointId; + } + } + + public static class AcknowledgeReport extends AbstractReport { + + private final CheckpointMetrics checkpointMetrics; + private final TaskStateSnapshot subtaskState; + + public AcknowledgeReport( + JobID jobID, + ExecutionAttemptID executionAttemptID, + long checkpointId, + CheckpointMetrics checkpointMetrics, + TaskStateSnapshot subtaskState) { + + super(jobID, executionAttemptID, checkpointId); + this.checkpointMetrics = checkpointMetrics; + this.subtaskState = subtaskState; + } + + public CheckpointMetrics getCheckpointMetrics() { + return checkpointMetrics; + } + + public TaskStateSnapshot getSubtaskState() { + return subtaskState; + } + } + + public static class DeclineReport extends AbstractReport { + + public final Throwable cause; + + public DeclineReport( + JobID jobID, + ExecutionAttemptID executionAttemptID, + long checkpointId, + Throwable cause) { + + super(jobID, executionAttemptID, checkpointId); + this.cause = cause; + } + + public Throwable getCause() { + return cause; + } + } + + public List getAcknowledgeReports() { + return acknowledgeReports; + } + + public List getDeclineReports() { + return declineReports; + } + + public OneShotLatch getAcknowledgeLatch() { + return acknowledgeLatch; + } + + public void setAcknowledgeLatch(OneShotLatch acknowledgeLatch) { + this.acknowledgeLatch = acknowledgeLatch; + } + + public OneShotLatch getDeclinedLatch() { + return declinedLatch; + } + + public void setDeclinedLatch(OneShotLatch declinedLatch) { + this.declinedLatch = declinedLatch; + } + + public void clear() { + acknowledgeReports.clear(); + declineReports.clear(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java index 2091e005aaa73..0b4b52f43dbeb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/BlockerCheckpointStreamFactory.java @@ -148,7 +148,7 @@ private void unblockAll() { } @Override - public void close() throws Exception { + public void close() throws IOException { } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java index 807229580855a..7d650a25b4810 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java @@ -53,6 +53,9 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskLocalStateStore; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TaskStateManagerImpl; import org.apache.flink.runtime.taskexecutor.TaskManagerConfiguration; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; @@ -165,6 +168,15 @@ public static void main(String[] args) throws Exception { BlobCacheService blobService = new BlobCacheService(mock(PermanentBlobCache.class), mock(TransientBlobCache.class)); + final TaskLocalStateStore localStateStore = new TaskLocalStateStore(jid, jobVertexId, 0); + final TaskStateManager slotStateManager = + new TaskStateManagerImpl( + jid, + executionAttemptID, + localStateStore, + null, + mock(CheckpointResponder.class)); + Task task = new Task( jobInformation, taskInformation, @@ -180,6 +192,7 @@ public static void main(String[] args) throws Exception { ioManager, networkEnvironment, new BroadcastVariableManager(), + slotStateManager, new NoOpTaskManagerActions(), new NoOpInputSplitProvider(), new NoOpCheckpointResponder(), diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 58022a9c61ae7..62a92a5dbfe4d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -28,7 +28,7 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.MetricOptions; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.Gauge; @@ -36,23 +36,21 @@ import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType; -import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DefaultKeyedStateStore; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; import org.apache.flink.runtime.state.KeyGroupsList; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream; -import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateInitializationContextImpl; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; import org.apache.flink.runtime.state.VoidNamespace; @@ -63,21 +61,23 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.CloseableIterable; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.IOUtils; import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.Closeable; import java.io.IOException; import java.io.Serializable; -import java.util.Collection; import java.util.ConcurrentModificationException; import java.util.HashMap; import java.util.Map; -import static org.apache.flink.util.Preconditions.checkArgument; - /** * Base class for all stream operators. Operators that contain a user function should extend the class * {@link AbstractUdfStreamOperator} instead (which is a specialized subclass of this class). @@ -94,7 +94,7 @@ */ @PublicEvolving public abstract class AbstractStreamOperator - implements StreamOperator, Serializable, KeyContext { + implements StreamOperator, Serializable { private static final long serialVersionUID = 1L; @@ -152,6 +152,8 @@ public abstract class AbstractStreamOperator /** Operator state backend / store. */ private transient OperatorStateBackend operatorStateBackend; + private transient StreamTaskStateManager streamTaskStateManager; + // --------------- Metrics --------------------------- /** Metric group for the operator. */ @@ -177,10 +179,11 @@ public abstract class AbstractStreamOperator @Override public void setup(StreamTask containingTask, StreamConfig config, Output> output) { + final Environment environment = containingTask.getEnvironment(); this.container = containingTask; this.config = config; try { - OperatorMetricGroup operatorMetricGroup = container.getEnvironment().getMetricGroup().addOperator(config.getOperatorID(), config.getOperatorName()); + OperatorMetricGroup operatorMetricGroup = environment.getMetricGroup().addOperator(config.getOperatorID(), config.getOperatorName()); this.output = new CountingOutput(output, operatorMetricGroup.getIOMetricGroup().getNumRecordsOutCounter()); if (config.isChainStart()) { operatorMetricGroup.getIOMetricGroup().reuseInputMetricsForTask(); @@ -194,7 +197,7 @@ public void setup(StreamTask containingTask, StreamConfig config, Output containingTask, StreamConfig config, Output keyedStateHandlesRaw = null; - Collection operatorStateHandlesRaw = null; - Collection operatorStateHandlesBackend = null; + final TypeSerializer keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); - boolean restoring = (null != stateHandles); + final StreamTask containingTask = + Preconditions.checkNotNull(getContainingTask()); + final CloseableRegistry streamTaskCloseableRegistry = + Preconditions.checkNotNull(containingTask.getCancelables()); + final StreamTaskStateManager streamTaskStateManager = + Preconditions.checkNotNull(containingTask.getStreamTaskStateManager()); - initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class - - if (getKeyedStateBackend() != null && timeServiceManager == null) { - timeServiceManager = new InternalTimeServiceManager<>( - getKeyedStateBackend().getNumberOfKeyGroups(), - getKeyedStateBackend().getKeyGroupRange(), + final StreamOperatorStateContext context = + streamTaskStateManager.streamOperatorStateContext( this, - getRuntimeContext().getProcessingTimeService()); - } + keySerializer, + streamTaskCloseableRegistry); - if (restoring) { + this.operatorStateBackend = context.operatorStateBackend(); + this.keyedStateBackend = context.keyedStateBackend(); - //pass directly - operatorStateHandlesBackend = stateHandles.getManagedOperatorState(); - operatorStateHandlesRaw = stateHandles.getRawOperatorState(); - - if (null != getKeyedStateBackend()) { - //only use the keyed state if it is meant for us (aka head operator) - keyedStateHandlesRaw = stateHandles.getRawKeyedState(); - } + if (keyedStateBackend != null) { + this.keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getExecutionConfig()); } - checkpointStreamFactory = container.createCheckpointStreamFactory(this); + timeServiceManager = context.internalTimerServiceManager(); + checkpointStreamFactory = context.checkpointStreamFactory(); - initOperatorState(operatorStateHandlesBackend); + CloseableIterable keyedStateInputs = context.rawKeyedStateInputs(); + CloseableIterable operatorStateInputs = context.rawOperatorStateInputs(); - StateInitializationContext initializationContext = new StateInitializationContextImpl( - restoring, // information whether we restore or start for the first time + try { + StateInitializationContext initializationContext = new StateInitializationContextImpl( + context.isRestored(), // information whether we restore or start for the first time operatorStateBackend, // access to operator state backend keyedStateStore, // access to keyed state backend - keyedStateHandlesRaw, // access to keyed state stream - operatorStateHandlesRaw, // access to operator state stream - getContainingTask().getCancelables()); // access to register streams for canceling + keyedStateInputs, // access to operator state stream + operatorStateInputs); // access to keyed state stream - initializeState(initializationContext); + initializeState(initializationContext); + } finally { + closeFromRegistry(operatorStateInputs, streamTaskCloseableRegistry); + closeFromRegistry(keyedStateInputs, streamTaskCloseableRegistry); + } + } + + private static void closeFromRegistry(Closeable closeable, CloseableRegistry registry) { + if (registry.unregisterCloseable(closeable)) { + IOUtils.closeQuietly(closeable); + } } /** @@ -270,39 +279,6 @@ public final void initializeState(OperatorSubtaskState stateHandles) throws Exce @Override public void open() throws Exception {} - private void initKeyedState() { - try { - TypeSerializer keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); - // create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer - if (null != keySerializer) { - KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( - container.getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(), - container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(), - container.getEnvironment().getTaskInfo().getIndexOfThisSubtask()); - - this.keyedStateBackend = container.createKeyedStateBackend( - keySerializer, - // The maximum parallelism == number of key group - container.getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(), - subTaskKeyGroupRange); - - this.keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getExecutionConfig()); - } - - } catch (Exception e) { - throw new IllegalStateException("Could not initialize keyed state backend.", e); - } - } - - private void initOperatorState(Collection operatorStateHandles) { - try { - // create an operator state backend - this.operatorStateBackend = container.createOperatorStateBackend(this, operatorStateHandles); - } catch (Exception e) { - throw new IllegalStateException("Could not initialize operator state backend.", e); - } - } - /** * This method is called after all records have been added to the operators via the methods * {@link OneInputStreamOperator#processElement(StreamRecord)}, or @@ -316,7 +292,38 @@ private void initOperatorState(Collection operatorStateHand * @throws Exception An exception in this method causes the operator to fail. */ @Override - public void close() throws Exception {} + public void close() throws Exception { + + Exception exception = null; + + StreamTask containingTask = getContainingTask(); + + if (containingTask == null) { + return; // without a containing task, we have nothing to do here. + } + + CloseableRegistry cancelables = containingTask.getCancelables(); + + try { + if (cancelables.unregisterCloseable(operatorStateBackend)) { + operatorStateBackend.close(); + } + } catch (Exception e) { + exception = e; + } + + try { + if (cancelables.unregisterCloseable(keyedStateBackend)) { + keyedStateBackend.close(); + } + } catch (Exception e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } + + if (exception != null) { + throw exception; + } + } /** * This method is called at the very end of the operator's life, both in the case of a successful @@ -328,12 +335,26 @@ public void close() throws Exception {} @Override public void dispose() throws Exception { - if (operatorStateBackend != null) { - operatorStateBackend.dispose(); + Exception exception = null; + + try { + if (operatorStateBackend != null) { + operatorStateBackend.dispose(); + } + } catch (Exception e) { + exception = e; } - if (keyedStateBackend != null) { - keyedStateBackend.dispose(); + try { + if (keyedStateBackend != null) { + keyedStateBackend.dispose(); + } + } catch (Exception e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } + + if (exception != null) { + throw exception; } } @@ -426,25 +447,11 @@ public void snapshotState(StateSnapshotContext context) throws Exception { * @param context context that allows to register different states. */ public void initializeState(StateInitializationContext context) throws Exception { - if (getKeyedStateBackend() != null) { - KeyGroupsList localKeyGroupRange = getKeyedStateBackend().getKeyGroupRange(); - - // and then initialize the timer services - for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) { - int keyGroupIdx = streamProvider.getKeyGroupId(); - checkArgument(localKeyGroupRange.contains(keyGroupIdx), - "Key Group " + keyGroupIdx + " does not belong to the local range."); - - timeServiceManager.restoreStateForKeyGroup( - new DataInputViewStreamWrapper(streamProvider.getStream()), - keyGroupIdx, getUserCodeClassloader()); - } - } } @Override - public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { + public void notifyCheckpointComplete(long checkpointId) throws Exception { if (keyedStateBackend != null) { keyedStateBackend.notifyCheckpointComplete(checkpointId); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index 329ce183ce5d0..71f3d3eb6d92c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -123,8 +123,8 @@ public void dispose() throws Exception { // ------------------------------------------------------------------------ @Override - public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - super.notifyOfCompletedCheckpoint(checkpointId); + public void notifyCheckpointComplete(long checkpointId) throws Exception { + super.notifyCheckpointComplete(checkpointId); if (userFunction instanceof CheckpointListener) { ((CheckpointListener) userFunction).notifyCheckpointComplete(checkpointId); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 38b4aeedb1b2c..07adfe409345b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,8 +20,8 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; @@ -44,7 +44,7 @@ * @param The output type of the operator */ @PublicEvolving -public interface StreamOperator extends Serializable { +public interface StreamOperator extends CheckpointListener, KeyContext, Serializable { // ------------------------------------------------------------------------ // life cycle @@ -104,21 +104,9 @@ OperatorSnapshotResult snapshotState( CheckpointOptions checkpointOptions) throws Exception; /** - * Provides state handles to restore the operator state. - * - * @param stateHandles state handles to the operator state. - */ - void initializeState(OperatorSubtaskState stateHandles) throws Exception; - - /** - * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. - * - * @param checkpointId The ID of the checkpoint that has been completed. - * - * @throws Exception Exceptions during checkpoint acknowledgement may be forwarded and will cause - * the program to fail and enter recovery. + * Provides a context to initialize all state in the operator. */ - void notifyOfCompletedCheckpoint(long checkpointId) throws Exception; + void initializeState() throws Exception; // ------------------------------------------------------------------------ // miscellaneous diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateContext.java new file mode 100644 index 0000000000000..f838760b6e520 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateContext.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.util.CloseableIterable; + +/** + * This interface represents a context from which a stream operator can initialize everything connected state to + * state such as e.g. backends, raw state, and timer service manager. + */ +public interface StreamOperatorStateContext { + + /** + * Returns true, the states provided by this context are restored from a checkpoint/savepoint. + */ + boolean isRestored(); + + /** + * Returns the operator state backend for the stream operator. + */ + OperatorStateBackend operatorStateBackend(); + + + /** + * Returns the keyed state backend for the stream operator. This method returns null for non-keyed operators. + */ + AbstractKeyedStateBackend keyedStateBackend(); + + /** + * Returns the internal timer service manager for the stream operator. This method returns null for non-keyed + * operators. + */ + InternalTimeServiceManager internalTimerServiceManager(); + + /** + * Returns the checkpoint stream factory for the stream operator. + */ + CheckpointStreamFactory checkpointStreamFactory(); + + /** + * Returns an iterable to obtain input streams for previously stored operator state partitions that are assigned to + * this stream operator. + */ + CloseableIterable rawOperatorStateInputs(); + + /** + * Returns an iterable to obtain input streams for previously stored keyed state partitions that are assigned to + * this operator. This method returns null for non-keyed operators. + */ + CloseableIterable rawKeyedStateInputs(); + +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManager.java new file mode 100644 index 0000000000000..59c87f06c29bb --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManager.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * This is the interface through which stream task expose a {@link StreamOperatorStateContext} to their operators. + * Operators, in turn, can use the context to initialize everything connected to their state, such as backends or + * a timer service manager. + */ +public interface StreamTaskStateManager { + + /** + * Returns the {@link StreamOperatorStateContext} for an {@link AbstractStreamOperator} that runs in the stream + * task that owns this manager. + * + * @param operator the operator for which the context is created. Cannot be null. + * @param keySerializer the key-serializer for the operator. Can be null. + * @param streamTaskCloseableRegistry the closeable registry to which created closeable objects will be registered. + * @return a context from which the given operator can initialize everything related to state. + * @throws Exception when something went wrong while creating the context. + */ + StreamOperatorStateContext streamOperatorStateContext( + @Nonnull AbstractStreamOperator operator, + @Nullable TypeSerializer keySerializer, + @Nonnull CloseableRegistry streamTaskCloseableRegistry) throws Exception; +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImpl.java new file mode 100644 index 0000000000000..ef92b29428647 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImpl.java @@ -0,0 +1,640 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.util.OperatorSubtaskDescriptionText; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.util.CloseableIterable; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.io.IOUtils; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * This class is the main implementation of a {@link StreamTaskStateManager}. This class obtains the state to create + * {@link StreamOperatorStateContext} objects for stream operators from the {@link TaskStateManager} of the task that + * runs the stream task and hence the operator. + * + *

This implementation operates on top a {@link TaskStateManager}, from which it receives everything required to + * restore state in the backends from checkpoints or savepoints. + */ +public class StreamTaskStateManagerImpl implements StreamTaskStateManager { + + /** + * The environment of the task. This is required as parameter to construct state backends via their factory. + */ + private final Environment environment; + + /** This processing time service is required to construct an internal timer service manager. */ + private final ProcessingTimeService processingTimeService; + + /** The state manager of the tasks provides the information used to restore potential previous state. */ + private final TaskStateManager taskStateManager; + + /** This object is the factory for everything related to state backends and checkpointing. */ + private final StateBackend stateBackend; + + public StreamTaskStateManagerImpl( + Environment environment, + StateBackend stateBackend, + ProcessingTimeService processingTimeService) { + + this.environment = environment; + this.taskStateManager = Preconditions.checkNotNull(environment.getTaskStateManager()); + this.stateBackend = Preconditions.checkNotNull(stateBackend); + this.processingTimeService = processingTimeService; + } + + // ----------------------------------------------------------------------------------------------------------------- + + @Override + public StreamOperatorStateContext streamOperatorStateContext( + AbstractStreamOperator operator, + TypeSerializer keySerializer, + CloseableRegistry streamTaskCloseableRegistry) throws Exception { + + TaskInfo taskInfo = environment.getTaskInfo(); + OperatorSubtaskDescriptionText operatorSubtaskDescription = + new OperatorSubtaskDescriptionText( + operator.getOperatorID(), + operator.getClass(), + taskInfo.getIndexOfThisSubtask(), + taskInfo.getNumberOfParallelSubtasks()); + + final String operatorIdentifierText = operatorSubtaskDescription.toString(); + + final OperatorSubtaskState operatorSubtaskStateFromJobManager = + taskStateManager.operatorStates(operator.getOperatorID()); + + final boolean restoring = (operatorSubtaskStateFromJobManager != null); + + AbstractKeyedStateBackend keyedStatedBackend = null; + OperatorStateBackend operatorStateBackend = null; + CloseableIterable rawKeyedStateInputs = null; + CloseableIterable rawOperatorStateInputs = null; + CheckpointStreamFactory checkpointStreamFactory = null; + InternalTimeServiceManager timeServiceManager = null; + + try { + + // -------------- Keyed State Backend -------------- + keyedStatedBackend = keyedStatedBackend( + keySerializer, + operatorIdentifierText, + operatorSubtaskStateFromJobManager, + streamTaskCloseableRegistry); + + // -------------- Operator State Backend -------------- + operatorStateBackend = operatorStateBackend( + operatorIdentifierText, + operatorSubtaskStateFromJobManager, + streamTaskCloseableRegistry); + + // -------------- Raw State Streams -------------- + rawKeyedStateInputs = rawKeyedStateInputs(operatorSubtaskStateFromJobManager); + streamTaskCloseableRegistry.registerCloseable(rawKeyedStateInputs); + + rawOperatorStateInputs = rawOperatorStateInputs(operatorSubtaskStateFromJobManager); + streamTaskCloseableRegistry.registerCloseable(rawOperatorStateInputs); + + // -------------- Checkpoint Stream Factory -------------- + checkpointStreamFactory = streamFactory(operatorIdentifierText); + streamTaskCloseableRegistry.registerCloseable(checkpointStreamFactory); + + // -------------- Internal Timer Service Manager -------------- + timeServiceManager = internalTimeServiceManager(keyedStatedBackend, operator, rawKeyedStateInputs); + + // -------------- Preparing return value -------------- + + return new StreamOperatorStateContextImpl( + restoring, + operatorStateBackend, + keyedStatedBackend, + timeServiceManager, + rawOperatorStateInputs, + rawKeyedStateInputs, + checkpointStreamFactory); + } catch (Exception ex) { + + // cleanup if something went wrong before results got published. + if (streamTaskCloseableRegistry.unregisterCloseable(keyedStatedBackend)) { + IOUtils.closeQuietly(keyedStatedBackend); + } + + if (streamTaskCloseableRegistry.unregisterCloseable(operatorStateBackend)) { + IOUtils.closeQuietly(keyedStatedBackend); + } + + if (streamTaskCloseableRegistry.unregisterCloseable(rawKeyedStateInputs)) { + IOUtils.closeQuietly(rawKeyedStateInputs); + } + + if (streamTaskCloseableRegistry.unregisterCloseable(rawOperatorStateInputs)) { + IOUtils.closeQuietly(rawOperatorStateInputs); + } + + if (streamTaskCloseableRegistry.unregisterCloseable(rawOperatorStateInputs)) { + IOUtils.closeQuietly(rawOperatorStateInputs); + } + + throw new Exception("Exception while creating StreamOperatorStateContext.", ex); + } + } + + protected InternalTimeServiceManager internalTimeServiceManager( + AbstractKeyedStateBackend keyedStatedBackend, + KeyContext keyContext, //the operator + Iterable rawKeyedStates) throws Exception { + + if (keyedStatedBackend == null) { + return null; + } + + final KeyGroupRange keyGroupRange = keyedStatedBackend.getKeyGroupRange(); + + final InternalTimeServiceManager timeServiceManager = new InternalTimeServiceManager<>( + keyedStatedBackend.getNumberOfKeyGroups(), + keyGroupRange, + keyContext, + processingTimeService); + + // and then initialize the timer services + for (KeyGroupStatePartitionStreamProvider streamProvider : rawKeyedStates) { + int keyGroupIdx = streamProvider.getKeyGroupId(); + + Preconditions.checkArgument(keyGroupRange.contains(keyGroupIdx), + "Key Group " + keyGroupIdx + " does not belong to the local range."); + + timeServiceManager.restoreStateForKeyGroup( + new DataInputViewStreamWrapper(streamProvider.getStream()), + keyGroupIdx, environment.getUserClassLoader()); + } + + return timeServiceManager; + } + + protected OperatorStateBackend operatorStateBackend( + String operatorIdentifierText, + OperatorSubtaskState operatorSubtaskStateFromJobManager, + CloseableRegistry backendCloseableRegistry) throws Exception { + + //TODO search in local state for a local recovery opportunity. + + return createOperatorStateBackendFromJobManagerState( + operatorIdentifierText, + operatorSubtaskStateFromJobManager, + backendCloseableRegistry); + } + + protected AbstractKeyedStateBackend keyedStatedBackend( + TypeSerializer keySerializer, + String operatorIdentifierText, + OperatorSubtaskState operatorSubtaskStateFromJobManager, + CloseableRegistry backendCloseableRegistry) throws Exception { + + if (keySerializer == null) { + return null; + } + + //TODO search in local state for a local recovery opportunity. + + return createKeyedStatedBackendFromJobManagerState( + keySerializer, + operatorIdentifierText, + operatorSubtaskStateFromJobManager, + backendCloseableRegistry); + } + + protected CheckpointStreamFactory streamFactory(String operatorIdentifierText) throws IOException { + return stateBackend.createStreamFactory(environment.getJobID(), operatorIdentifierText); + } + + protected CloseableIterable rawOperatorStateInputs( + OperatorSubtaskState operatorSubtaskStateFromJobManager) { + + if (operatorSubtaskStateFromJobManager != null) { + + final CloseableRegistry closeableRegistry = new CloseableRegistry(); + + Collection rawOperatorState = + operatorSubtaskStateFromJobManager.getRawOperatorState(); + + return new CloseableIterable() { + @Override + public void close() throws IOException { + closeableRegistry.close(); + } + + @Nonnull + @Override + public Iterator iterator() { + return new OperatorStateStreamIterator( + DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, + rawOperatorState.iterator(), closeableRegistry); + } + }; + } + + return CloseableIterable.empty(); + } + + protected CloseableIterable rawKeyedStateInputs( + OperatorSubtaskState operatorSubtaskStateFromJobManager) { + + if (operatorSubtaskStateFromJobManager != null) { + + Collection rawKeyedState = operatorSubtaskStateFromJobManager.getRawKeyedState(); + Collection keyGroupsStateHandles = transform(rawKeyedState); + final CloseableRegistry closeableRegistry = new CloseableRegistry(); + + return new CloseableIterable() { + @Override + public void close() throws IOException { + closeableRegistry.close(); + } + + @Override + public Iterator iterator() { + return new KeyGroupStreamIterator(keyGroupsStateHandles.iterator(), closeableRegistry); + } + }; + } + + return CloseableIterable.empty(); + } + + // ================================================================================================================= + + private OperatorStateBackend createOperatorStateBackendFromJobManagerState( + String operatorIdentifierText, + OperatorSubtaskState operatorSubtaskStateFromJobManager, + CloseableRegistry backendCloseableRegistry) throws Exception { + + final OperatorStateBackend operatorStateBackend = + stateBackend.createOperatorStateBackend(environment, operatorIdentifierText); + + backendCloseableRegistry.registerCloseable(operatorStateBackend); + + Collection managedOperatorState = null; + + if (operatorSubtaskStateFromJobManager != null) { + managedOperatorState = operatorSubtaskStateFromJobManager.getManagedOperatorState(); + } + + operatorStateBackend.restore(managedOperatorState); + + return operatorStateBackend; + } + + private AbstractKeyedStateBackend createKeyedStatedBackendFromJobManagerState( + TypeSerializer keySerializer, + String operatorIdentifierText, + OperatorSubtaskState operatorSubtaskStateFromJobManager, + CloseableRegistry backendCloseableRegistry) throws Exception { + + final AbstractKeyedStateBackend keyedStateBackend = createKeyedStateBackend( + operatorIdentifierText, + keySerializer); + + backendCloseableRegistry.registerCloseable(keyedStateBackend); + + Collection managedKeyedState = null; + + if (operatorSubtaskStateFromJobManager != null) { + managedKeyedState = operatorSubtaskStateFromJobManager.getManagedKeyedState(); + } + + keyedStateBackend.restore(managedKeyedState); + + return keyedStateBackend; + } + + private AbstractKeyedStateBackend createKeyedStateBackend( + String operatorIdentifier, + TypeSerializer keySerializer) throws Exception { + + TaskInfo taskInfo = environment.getTaskInfo(); + + final KeyGroupRange keyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( + taskInfo.getMaxNumberOfParallelSubtasks(), + taskInfo.getNumberOfParallelSubtasks(), + taskInfo.getIndexOfThisSubtask()); + + return stateBackend.createKeyedStateBackend( + environment, + environment.getJobID(), + operatorIdentifier, + keySerializer, + taskInfo.getMaxNumberOfParallelSubtasks(), //TODO check: this is numberOfKeyGroups !!!! + keyGroupRange, + environment.getTaskKvStateRegistry()); + } + + // ================================================================================================================= + + private static class KeyGroupStreamIterator + extends AbstractStateStreamIterator { + + private Iterator> currentOffsetsIterator; + + KeyGroupStreamIterator( + Iterator stateHandleIterator, CloseableRegistry closableRegistry) { + + super(stateHandleIterator, closableRegistry); + } + + @Override + public boolean hasNext() { + + if (null != currentStateHandle && currentOffsetsIterator.hasNext()) { + + return true; + } + + closeCurrentStream(); + + while (stateHandleIterator.hasNext()) { + currentStateHandle = stateHandleIterator.next(); + if (currentStateHandle.getKeyGroupRange().getNumberOfKeyGroups() > 0) { + currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator(); + + return true; + } + } + + return false; + } + + @Override + public KeyGroupStatePartitionStreamProvider next() { + + if (!hasNext()) { + + throw new NoSuchElementException("Iterator exhausted"); + } + + Tuple2 keyGroupOffset = currentOffsetsIterator.next(); + try { + if (null == currentStream) { + openCurrentStream(); + } + + currentStream.seek(keyGroupOffset.f1); + return new KeyGroupStatePartitionStreamProvider(currentStream, keyGroupOffset.f0); + + } catch (IOException ioex) { + return new KeyGroupStatePartitionStreamProvider(ioex, keyGroupOffset.f0); + } + } + } + + private static class OperatorStateStreamIterator + extends AbstractStateStreamIterator { + + private final String stateName; //TODO since we only support a single named state in raw, this could be dropped + private long[] offsets; + private int offPos; + + OperatorStateStreamIterator( + String stateName, + Iterator stateHandleIterator, + CloseableRegistry closableRegistry) { + + super(stateHandleIterator, closableRegistry); + this.stateName = Preconditions.checkNotNull(stateName); + } + + @Override + public boolean hasNext() { + + if (null != offsets && offPos < offsets.length) { + + return true; + } + + closeCurrentStream(); + + while (stateHandleIterator.hasNext()) { + currentStateHandle = stateHandleIterator.next(); + OperatorStateHandle.StateMetaInfo metaInfo = + currentStateHandle.getStateNameToPartitionOffsets().get(stateName); + + if (null != metaInfo) { + long[] metaOffsets = metaInfo.getOffsets(); + if (null != metaOffsets && metaOffsets.length > 0) { + this.offsets = metaOffsets; + this.offPos = 0; + + if (closableRegistry.unregisterCloseable(currentStream)) { + IOUtils.closeQuietly(currentStream); + currentStream = null; + } + + return true; + } + } + } + + return false; + } + + @Override + public StatePartitionStreamProvider next() { + + if (!hasNext()) { + + throw new NoSuchElementException("Iterator exhausted"); + } + + long offset = offsets[offPos++]; + + try { + if (null == currentStream) { + openCurrentStream(); + } + + currentStream.seek(offset); + return new StatePartitionStreamProvider(currentStream); + + } catch (IOException ioex) { + return new StatePartitionStreamProvider(ioex); + } + } + } + + private abstract static class AbstractStateStreamIterator< + T extends StatePartitionStreamProvider, H extends StreamStateHandle> + implements Iterator { + + protected final Iterator stateHandleIterator; + protected final CloseableRegistry closableRegistry; + + protected H currentStateHandle; + protected FSDataInputStream currentStream; + + AbstractStateStreamIterator( + Iterator stateHandleIterator, + CloseableRegistry closableRegistry) { + + this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator); + this.closableRegistry = Preconditions.checkNotNull(closableRegistry); + } + + protected void openCurrentStream() throws IOException { + + Preconditions.checkState(currentStream == null); + + FSDataInputStream stream = currentStateHandle.openInputStream(); + closableRegistry.registerCloseable(stream); + currentStream = stream; + } + + protected void closeCurrentStream() { + if (closableRegistry.unregisterCloseable(currentStream)) { + IOUtils.closeQuietly(currentStream); + } + currentStream = null; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Read only Iterator"); + } + } + + private static Collection transform(Collection keyedStateHandles) { + + if (keyedStateHandles == null) { + return null; + } + + List keyGroupsStateHandles = new ArrayList<>(keyedStateHandles.size()); + + for (KeyedStateHandle keyedStateHandle : keyedStateHandles) { + + if (keyedStateHandle instanceof KeyGroupsStateHandle) { + keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle); + } else if (keyedStateHandle != null) { + throw new IllegalStateException("Unexpected state handle type, " + + "expected: " + KeyGroupsStateHandle.class + + ", but found: " + keyedStateHandle.getClass() + "."); + } + } + + return keyGroupsStateHandles; + } + + private static class StreamOperatorStateContextImpl implements StreamOperatorStateContext { + + private final boolean restored; + + private final OperatorStateBackend operatorStateBackend; + private final AbstractKeyedStateBackend keyedStateBackend; + private final InternalTimeServiceManager internalTimeServiceManager; + + private final CloseableIterable rawOperatorStateInputs; + private final CloseableIterable rawKeyedStateInputs; + + private final CheckpointStreamFactory checkpointStreamFactory; + + StreamOperatorStateContextImpl( + boolean restored, + OperatorStateBackend operatorStateBackend, + AbstractKeyedStateBackend keyedStateBackend, + InternalTimeServiceManager internalTimeServiceManager, + CloseableIterable rawOperatorStateInputs, + CloseableIterable rawKeyedStateInputs, + CheckpointStreamFactory checkpointStreamFactory) { + + this.restored = restored; + this.operatorStateBackend = operatorStateBackend; + this.keyedStateBackend = keyedStateBackend; + this.internalTimeServiceManager = internalTimeServiceManager; + this.rawOperatorStateInputs = rawOperatorStateInputs; + this.rawKeyedStateInputs = rawKeyedStateInputs; + this.checkpointStreamFactory = checkpointStreamFactory; + } + + @Override + public boolean isRestored() { + return restored; + } + + @Override + public AbstractKeyedStateBackend keyedStateBackend() { + return keyedStateBackend; + } + + @Override + public OperatorStateBackend operatorStateBackend() { + return operatorStateBackend; + } + + @Override + public InternalTimeServiceManager internalTimerServiceManager() { + return internalTimeServiceManager; + } + + @Override + public CheckpointStreamFactory checkpointStreamFactory() { + return checkpointStreamFactory; + } + + @Override + public CloseableIterable rawOperatorStateInputs() { + return rawOperatorStateInputs; + } + + @Override + public CloseableIterable rawKeyedStateInputs() { + return rawKeyedStateInputs; + } + } + +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java index 370d021dc32fd..fbf9bef97f908 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java @@ -47,8 +47,8 @@ /** * Generic Sink that emits its input elements into an arbitrary backend. This sink is integrated with Flink's checkpointing * mechanism and can provide exactly-once guarantees; depending on the storage backend and sink/committer implementation. - *

- * Incoming records are stored within a {@link org.apache.flink.runtime.state.AbstractStateBackend}, and only committed if a + * + *

Incoming records are stored within a {@link org.apache.flink.runtime.state.AbstractStateBackend}, and only committed if a * checkpoint is completed. * * @param Type of the elements emitted by this sink @@ -204,8 +204,8 @@ private void cleanRestoredHandles() throws Exception { } @Override - public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - super.notifyOfCompletedCheckpoint(checkpointId); + public void notifyCheckpointComplete(long checkpointId) throws Exception { + super.notifyCheckpointComplete(checkpointId); synchronized (pendingCheckpoints) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 36e67485650da..f9c11dacc4568 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -21,7 +21,6 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.accumulators.Accumulator; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.fs.FileSystemSafetyNet; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; @@ -30,27 +29,25 @@ import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.CancelTaskException; -import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; -import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.OperatorStateBackend; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; +import org.apache.flink.runtime.util.OperatorSubtaskDescriptionText; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamTaskStateManager; +import org.apache.flink.streaming.api.operators.StreamTaskStateManagerImpl; import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer; @@ -63,7 +60,6 @@ import java.io.Closeable; import java.io.IOException; -import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -137,26 +133,24 @@ public abstract class StreamTask> protected OperatorChain operatorChain; /** The configuration of this streaming task. */ - private StreamConfig configuration; + protected StreamConfig configuration; /** Our state backend. We use this to create checkpoint streams and a keyed state backend. */ - private StateBackend stateBackend; + protected StateBackend stateBackend; - /** Keyed state backend for the head operator, if it is keyed. There can only ever be one. */ - private AbstractKeyedStateBackend keyedStateBackend; + /** State manager for this stream task. */ + protected StreamTaskStateManager streamTaskStateManager; /** * The internal {@link ProcessingTimeService} used to define the current * processing time (default = {@code System.currentTimeMillis()}) and * register timers for tasks to be executed in the future. */ - private ProcessingTimeService timerService; + protected ProcessingTimeService timerService; /** The map of user-defined accumulators of this task. */ private Map> accumulatorMap; - private TaskStateSnapshot taskStateSnapshot; - /** The currently active background materialization threads. */ private final CloseableRegistry cancelables = new CloseableRegistry(); @@ -209,6 +203,13 @@ public void setProcessingTimeService(ProcessingTimeService timeProvider) { timerService = timeProvider; } + protected StreamTaskStateManager createStreamTaskStateManager() { + return new StreamTaskStateManagerImpl( + getEnvironment(), + stateBackend, + timerService); + } + @Override public final void invoke() throws Exception { @@ -241,6 +242,8 @@ public final void invoke() throws Exception { timerService = new SystemProcessingTimeService(this, getCheckpointLock(), timerThreadFactory); } + streamTaskStateManager = createStreamTaskStateManager(); + operatorChain = new OperatorChain<>(this); headOperator = operatorChain.getHeadOperator(); @@ -512,6 +515,10 @@ public StreamStatusMaintainer getStreamStatusMaintainer() { return operatorChain; } + public StreamTaskStateManager getStreamTaskStateManager() { + return streamTaskStateManager; + } + Output> getHeadOutput() { return operatorChain.getChainEntryPoint(); } @@ -524,11 +531,6 @@ RecordWriterOutput[] getStreamOutputs() { // Checkpoint and Restore // ------------------------------------------------------------------------ - @Override - public void setInitialState(TaskStateSnapshot taskStateHandles) { - this.taskStateSnapshot = taskStateHandles; - } - @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { try { @@ -649,7 +651,7 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception { for (StreamOperator operator : operatorChain.getAllOperators()) { if (operator != null) { - operator.notifyOfCompletedCheckpoint(checkpointId); + operator.notifyCheckpointComplete(checkpointId); } } } @@ -675,26 +677,11 @@ private void checkpointState( private void initializeState() throws Exception { - boolean restored = null != taskStateSnapshot; - - if (restored) { - initializeOperators(true); - taskStateSnapshot = null; // free for GC - } else { - initializeOperators(false); - } - } - - private void initializeOperators(boolean restored) throws Exception { StreamOperator[] allOperators = operatorChain.getAllOperators(); - for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) { - StreamOperator operator = allOperators[chainIdx]; + + for (StreamOperator operator : allOperators) { if (null != operator) { - if (restored && taskStateSnapshot != null) { - operator.initializeState(taskStateSnapshot.getSubtaskStateByOperatorID(operator.getOperatorID())); - } else { - operator.initializeState(null); - } + operator.initializeState(); } } } @@ -719,82 +706,24 @@ private StateBackend createStateBackend() throws Exception { } } - public OperatorStateBackend createOperatorStateBackend( - StreamOperator op, Collection restoreStateHandles) throws Exception { - - Environment env = getEnvironment(); - String opId = createOperatorIdentifier(op, getConfiguration().getVertexID()); - - OperatorStateBackend operatorStateBackend = stateBackend.createOperatorStateBackend(env, opId); - - // let operator state backend participate in the operator lifecycle, i.e. make it responsive to cancelation - cancelables.registerCloseable(operatorStateBackend); - - // restore if we have some old state - if (null != restoreStateHandles) { - operatorStateBackend.restore(restoreStateHandles); - } - - return operatorStateBackend; - } - - public AbstractKeyedStateBackend createKeyedStateBackend( - TypeSerializer keySerializer, - int numberOfKeyGroups, - KeyGroupRange keyGroupRange) throws Exception { - - if (keyedStateBackend != null) { - throw new RuntimeException("The keyed state backend can only be created once."); - } - - String operatorIdentifier = createOperatorIdentifier( - headOperator, - configuration.getVertexID()); - - keyedStateBackend = stateBackend.createKeyedStateBackend( - getEnvironment(), - getEnvironment().getJobID(), - operatorIdentifier, - keySerializer, - numberOfKeyGroups, - keyGroupRange, - getEnvironment().getTaskKvStateRegistry()); - - // let keyed state backend participate in the operator lifecycle, i.e. make it responsive to cancelation - cancelables.registerCloseable(keyedStateBackend); - - // restore if we have some old state - Collection restoreKeyedStateHandles = null; - - if (taskStateSnapshot != null) { - OperatorSubtaskState stateByOperatorID = - taskStateSnapshot.getSubtaskStateByOperatorID(headOperator.getOperatorID()); - restoreKeyedStateHandles = stateByOperatorID != null ? stateByOperatorID.getManagedKeyedState() : null; - } - - keyedStateBackend.restore(restoreKeyedStateHandles); - - @SuppressWarnings("unchecked") - AbstractKeyedStateBackend typedBackend = (AbstractKeyedStateBackend) keyedStateBackend; - return typedBackend; - } - /** * This is only visible because * {@link org.apache.flink.streaming.runtime.operators.GenericWriteAheadSink} uses the * checkpoint stream factory to write write-ahead logs. This should not be used for * anything else. */ - public CheckpointStreamFactory createCheckpointStreamFactory(StreamOperator operator) throws IOException { + public CheckpointStreamFactory createCheckpointStreamFactory( + StreamOperator operator) throws IOException { return stateBackend.createStreamFactory( - getEnvironment().getJobID(), - createOperatorIdentifier(operator, configuration.getVertexID())); + getEnvironment().getJobID(), + createOperatorIdentifier(operator)); } - public CheckpointStreamFactory createSavepointStreamFactory(StreamOperator operator, String targetLocation) throws IOException { + public CheckpointStreamFactory createSavepointStreamFactory( + StreamOperator operator, String targetLocation) throws IOException { return stateBackend.createSavepointStreamFactory( getEnvironment().getJobID(), - createOperatorIdentifier(operator, configuration.getVertexID()), + createOperatorIdentifier(operator), targetLocation); } @@ -802,12 +731,13 @@ protected CheckpointExceptionHandlerFactory createCheckpointExceptionHandlerFact return new CheckpointExceptionHandlerFactory(); } - private String createOperatorIdentifier(StreamOperator operator, int vertexId) { - + private String createOperatorIdentifier(StreamOperator operator) { TaskInfo taskInfo = getEnvironment().getTaskInfo(); - return operator.getClass().getSimpleName() + - "_" + operator.getOperatorID() + - "_(" + taskInfo.getIndexOfThisSubtask() + "/" + taskInfo.getNumberOfParallelSubtasks() + ")"; + return new OperatorSubtaskDescriptionText( + operator.getOperatorID(), + operator.getClass(), + taskInfo.getIndexOfThisSubtask(), + taskInfo.getNumberOfParallelSubtasks()).toString(); } /** @@ -884,7 +814,9 @@ private static final class AsyncCheckpointRunnable implements Runnable, Closeabl @Override public void run() { FileSystemSafetyNet.initializeSafetyNetForThread(); + final long checkpointId = checkpointMetaData.getCheckpointId(); try { + boolean hasState = false; final TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot(operatorSnapshotsInProgress.size()); @@ -906,20 +838,22 @@ public void run() { } final long asyncEndNanos = System.nanoTime(); - final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000; + final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000L; checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis); if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, - CheckpointingOperation.AsynCheckpointState.COMPLETED)) { + CheckpointingOperation.AsynCheckpointState.COMPLETED)) { TaskStateSnapshot acknowledgedState = hasState ? taskOperatorSubtaskStates : null; + TaskStateManager taskStateManager = owner.getEnvironment().getTaskStateManager(); + // we signal stateless tasks by reporting null, so that there are no attempts to assign empty state // to stateless tasks on restore. This enables simple job modifications that only concern // stateless without the need to assign them uids to match their (always empty) states. - owner.getEnvironment().acknowledgeCheckpoint( - checkpointMetaData.getCheckpointId(), + taskStateManager.reportStateHandles( + checkpointMetaData, checkpointMetrics, acknowledgedState); @@ -935,6 +869,7 @@ public void run() { checkpointMetaData.getCheckpointId()); } } catch (Exception e) { + e.printStackTrace(); // the state is completed if an exception occurred in the acknowledgeCheckpoint call // in order to clean up, we have to set it to RUNNING again. asyncCheckpointState.compareAndSet( @@ -948,7 +883,7 @@ public void run() { } Exception checkpointException = new Exception( - "Could not materialize checkpoint " + checkpointMetaData.getCheckpointId() + " for operator " + + "Could not materialize checkpoint " + checkpointId + " for operator " + owner.getName() + '.', e); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/source/InputFormatSourceFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/source/InputFormatSourceFunctionTest.java index b99119edbeb7e..e65d4c991864c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/source/InputFormatSourceFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/source/InputFormatSourceFunctionTest.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockEnvironment; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; @@ -258,7 +259,12 @@ private static class MockRuntimeContext extends StreamingRuntimeContext { private MockRuntimeContext(LifeCycleTestInputFormat format, int noOfSplits) { super(new MockStreamOperator(), - new MockEnvironment("no", 4 * MemoryManager.DEFAULT_PAGE_SIZE, null, 16), + new MockEnvironment( + "no", + 4 * MemoryManager.DEFAULT_PAGE_SIZE, + null, + 16, + new TestTaskStateManager()), Collections.>emptyMap()); this.noOfSplits = noOfSplits; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java index 726256009678f..c4cfcfb83d7c7 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java @@ -63,6 +63,7 @@ import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.verify; import static org.powermock.api.mockito.PowerMockito.doReturn; @@ -578,6 +579,9 @@ public void testFailingBackendSnapshotMethod() throws Exception { AbstractStreamOperator operator = mock(AbstractStreamOperator.class); when(operator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenCallRealMethod(); + doCallRealMethod().when(operator).close(); + doCallRealMethod().when(operator).dispose(); + // The amount of mocking in this test makes it necessary to make the // getCheckpointStreamFactory method visible for the test and to // overwrite its behaviour. @@ -588,10 +592,21 @@ public void testFailingBackendSnapshotMethod() throws Exception { RunnableFuture futureManagedOperatorStateHandle = mock(RunnableFuture.class); OperatorStateBackend operatorStateBackend = mock(OperatorStateBackend.class); - when(operatorStateBackend.snapshot(eq(checkpointId), eq(timestamp), eq(streamFactory), any(CheckpointOptions.class))).thenReturn(futureManagedOperatorStateHandle); + when(operatorStateBackend.snapshot( + eq(checkpointId), + eq(timestamp), + eq(streamFactory), + any(CheckpointOptions.class))).thenReturn(futureManagedOperatorStateHandle); AbstractKeyedStateBackend keyedStateBackend = mock(AbstractKeyedStateBackend.class); - when(keyedStateBackend.snapshot(eq(checkpointId), eq(timestamp), eq(streamFactory), eq(CheckpointOptions.forCheckpoint()))).thenThrow(failingException); + when(keyedStateBackend.snapshot( + eq(checkpointId), + eq(timestamp), + eq(streamFactory), + eq(CheckpointOptions.forCheckpoint()))).thenThrow(failingException); + + closeableRegistry.registerCloseable(operatorStateBackend); + closeableRegistry.registerCloseable(keyedStateBackend); Whitebox.setInternalState(operator, "operatorStateBackend", operatorStateBackend); Whitebox.setInternalState(operator, "keyedStateBackend", keyedStateBackend); @@ -612,6 +627,16 @@ public void testFailingBackendSnapshotMethod() throws Exception { verify(futureKeyedStateHandle).cancel(anyBoolean()); verify(futureOperatorStateHandle).cancel(anyBoolean()); verify(futureKeyedStateHandle).cancel(anyBoolean()); + + operator.close(); + + verify(operatorStateBackend).close(); + verify(keyedStateBackend).close(); + + operator.dispose(); + + verify(operatorStateBackend).dispose(); + verify(keyedStateBackend).dispose(); } /** diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java index f16ea2d83e2d4..a887b05d73767 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java @@ -83,10 +83,10 @@ public class AbstractUdfStreamOperatorLifecycleTest { "UDF::close"); private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " + - "getMetricGroup[], getOperatorID[], initializeState[class org.apache.flink.runtime.checkpoint.OperatorSubtaskState], " + - "notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " + - "org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " + - "org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " + + "getCurrentKey[], getMetricGroup[], getOperatorID[], initializeState[], " + + "notifyCheckpointComplete[long], open[], setChainingStrategy[class " + + "org.apache.flink.streaming.api.operators.ChainingStrategy], setCurrentKey[class java.lang.Object], " + + "setKeyContextElement1[class org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " + "setKeyContextElement2[class org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " + "setup[class org.apache.flink.streaming.runtime.tasks.StreamTask, class " + "org.apache.flink.streaming.api.graph.StreamConfig, interface " + diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java index 1ba2e776cf434..00e580401d06a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java @@ -18,8 +18,9 @@ package org.apache.flink.streaming.api.operators; +import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.state.KeyedStateStore; -import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; @@ -27,6 +28,13 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.DefaultOperatorStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; @@ -34,10 +42,17 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StateInitializationContextImpl; import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.runtime.state.TaskLocalStateStore; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TaskStateManagerImpl; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.util.LongArrayList; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.util.Preconditions; import org.junit.Assert; @@ -47,6 +62,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -54,6 +70,7 @@ import java.util.Set; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link StateInitializationContextImpl}. @@ -75,7 +92,6 @@ public void setUp() throws Exception { this.writtenOperatorStates = new HashSet<>(); this.closableRegistry = new CloseableRegistry(); - OperatorStateStore stateStore = mock(OperatorStateStore.class); ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64); @@ -124,14 +140,69 @@ public void setUp() throws Exception { operatorStateHandles.add(operatorStateHandle); } + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( + Collections.emptyList(), + operatorStateHandles, + Collections.emptyList(), + keyedStateHandles); + + OperatorID operatorID = new OperatorID(); + TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(); + taskStateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + + JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore(0L, taskStateSnapshot); + + TaskStateManager manager = new TaskStateManagerImpl( + new JobID(), + new ExecutionAttemptID(), + mock(TaskLocalStateStore.class), + jobManagerTaskRestore, + mock(CheckpointResponder.class)); + + DummyEnvironment environment = new DummyEnvironment( + "test", + 1, + 0, + prev); + + environment.setTaskStateManager(manager); + + StateBackend stateBackend = new MemoryStateBackend(1024); + StreamTaskStateManager streamTaskStateManager = new StreamTaskStateManagerImpl( + environment, + stateBackend, + mock(ProcessingTimeService.class)) { + + @Override + protected InternalTimeServiceManager internalTimeServiceManager( + AbstractKeyedStateBackend keyedStatedBackend, + KeyContext keyContext, + Iterable rawKeyedStates) throws Exception { + + // We do not initialize a timer service manager here, because it would already consume the raw keyed + // state as part of initialization. For the purpose of this test, we want an unconsumed raw keyed + // stream. + return null; + } + }; + + AbstractStreamOperator mockOperator = mock(AbstractStreamOperator.class); + when(mockOperator.getOperatorID()).thenReturn(operatorID); + + StreamOperatorStateContext stateContext = streamTaskStateManager.streamOperatorStateContext( + mockOperator, + // notice that this essentially disables the previous test of the keyed stream because it was and is always + // consumed by the timer service. + mock(TypeSerializer.class), + closableRegistry); + this.initializationContext = new StateInitializationContextImpl( - true, - stateStore, + stateContext.isRestored(), + stateContext.operatorStateBackend(), mock(KeyedStateStore.class), - keyedStateHandles, - operatorStateHandles, - closableRegistry); + stateContext.rawKeyedStateInputs(), + stateContext.rawOperatorStateInputs()); } @Test @@ -213,7 +284,7 @@ public void close() throws Exception { Assert.assertNotNull(stateStreamProvider); if (count == stopCount) { - initializationContext.close(); + closableRegistry.close(); isClosed = true; } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java index 50dc4d485a8e5..55361abf76f5e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java @@ -29,15 +29,19 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream; import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream; +import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StatePartitionStreamProvider; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.junit.Assert; @@ -73,7 +77,6 @@ public Integer getKey(Integer value) throws Exception { MAX_PARALLELISM, 1 /* num subtasks */, 0 /* subtask index */); - testHarness.open(); for (int i = 0; i < 10; ++i) { @@ -87,7 +90,7 @@ public Integer getKey(Integer value) throws Exception { //-------------------------------------------------------------------------- restore op = new TestOneInputStreamOperator(true); - testHarness = new KeyedOneInputStreamOperatorTestHarness<>( + testHarness = new KeyedOneInputStreamOperatorTestHarness( op, new KeySelector() { @Override @@ -98,7 +101,26 @@ public Integer getKey(Integer value) throws Exception { TypeInformation.of(Integer.class), MAX_PARALLELISM, 1 /* num subtasks */, - 0 /* subtask index */); + 0 /* subtask index */) { + + @Override + protected StreamTaskStateManager createStreamTaskStateManager( + Environment env, + StateBackend stateBackend, + ProcessingTimeService processingTimeService) { + + return new StreamTaskStateManagerImpl(env, stateBackend, processingTimeService) { + @Override + protected InternalTimeServiceManager internalTimeServiceManager( + AbstractKeyedStateBackend keyedStatedBackend, + KeyContext keyContext, + Iterable rawKeyedStates) throws Exception { + + return null; + } + }; + } + }; testHarness.initializeState(handles); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImplTest.java new file mode 100644 index 0000000000000..ddc037c650ad8 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamTaskStateManagerImplTest.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.checkpoint.savepoint.CheckpointTestUtils; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TaskStateManagerImplTest; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.taskmanager.TestCheckpointResponder; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; +import org.apache.flink.util.CloseableIterable; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nullable; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Collections; +import java.util.Random; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Test for {@link StreamTaskStateManagerImpl}. + */ +public class StreamTaskStateManagerImplTest { + + @Test + public void testNoRestore() throws Exception { + + MemoryStateBackend stateBackend = spy(new MemoryStateBackend(1024)); + + // No job manager provided state to restore + StreamTaskStateManager streamTaskStateManager = streamTaskStateManager(stateBackend, null, true); + + OperatorID operatorID = new OperatorID(47L, 11L); + AbstractStreamOperator streamOperator = mock(AbstractStreamOperator.class); + when(streamOperator.getOperatorID()).thenReturn(operatorID); + + TypeSerializer typeSerializer = new IntSerializer(); + CloseableRegistry closeableRegistry = new CloseableRegistry(); + + StreamOperatorStateContext stateContext = streamTaskStateManager.streamOperatorStateContext( + streamOperator, + typeSerializer, + closeableRegistry); + + verify(stateBackend).createKeyedStateBackend( + any(Environment.class), + any(JobID.class), + any(String.class), + eq(typeSerializer), + anyInt(), + any(KeyGroupRange.class), + any(TaskKvStateRegistry.class)); + + verify(stateBackend).createOperatorStateBackend( + any(Environment.class), + any(String.class)); + + verify(stateBackend).createStreamFactory( + any(JobID.class), + any(String.class)); + + OperatorStateBackend operatorStateBackend = stateContext.operatorStateBackend(); + AbstractKeyedStateBackend keyedStateBackend = stateContext.keyedStateBackend(); + InternalTimeServiceManager timeServiceManager = stateContext.internalTimerServiceManager(); + CheckpointStreamFactory streamFactory = stateContext.checkpointStreamFactory(); + CloseableIterable keyedStateInputs = stateContext.rawKeyedStateInputs(); + CloseableIterable operatorStateInputs = stateContext.rawOperatorStateInputs(); + + Assert.assertEquals(false, stateContext.isRestored()); + Assert.assertNotNull(operatorStateBackend); + Assert.assertNotNull(keyedStateBackend); + Assert.assertNotNull(timeServiceManager); + Assert.assertNotNull(streamFactory); + Assert.assertNotNull(keyedStateInputs); + Assert.assertNotNull(operatorStateInputs); + + checkCloseablesRegistered( + closeableRegistry, + operatorStateBackend, + keyedStateBackend, + streamFactory, + keyedStateInputs, + operatorStateInputs); + + for (KeyGroupStatePartitionStreamProvider keyedStateInput : keyedStateInputs) { + Assert.fail(); + } + + for (StatePartitionStreamProvider operatorStateInput : operatorStateInputs) { + Assert.fail(); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testWithRestore() throws Exception { + + StateBackend mockingBackend = spy(new StateBackend() { + @Override + public CheckpointStreamFactory createStreamFactory( + JobID jobId, String operatorIdentifier) throws IOException { + return mock(CheckpointStreamFactory.class); + } + + @Override + public CheckpointStreamFactory createSavepointStreamFactory( + JobID jobId, String operatorIdentifier, @Nullable String targetLocation) throws IOException { + return mock(CheckpointStreamFactory.class); + } + + @Override + public AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry) throws Exception { + return mock(AbstractKeyedStateBackend.class); + } + + @Override + public OperatorStateBackend createOperatorStateBackend( + Environment env, String operatorIdentifier) throws Exception { + return mock(OperatorStateBackend.class); + } + }); + + OperatorID operatorID = new OperatorID(47L, 11L); + TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(); + + Random random = new Random(0x42); + + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( + new OperatorStateHandle( + Collections.singletonMap( + "a", + new OperatorStateHandle.StateMetaInfo( + new long[]{0, 10}, + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)), + CheckpointTestUtils.createDummyStreamStateHandle(random)), + new OperatorStateHandle( + Collections.singletonMap( + "_default_", + new OperatorStateHandle.StateMetaInfo( + new long[]{0, 20, 30}, + OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)), + CheckpointTestUtils.createDummyStreamStateHandle(random)), + CheckpointTestUtils.createDummyKeyGroupStateHandle(random), + CheckpointTestUtils.createDummyKeyGroupStateHandle(random)); + + taskStateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + + JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore(0L, taskStateSnapshot); + + StreamTaskStateManager streamTaskStateManager = + streamTaskStateManager(mockingBackend, jobManagerTaskRestore, false); + + AbstractStreamOperator streamOperator = mock(AbstractStreamOperator.class); + when(streamOperator.getOperatorID()).thenReturn(operatorID); + + TypeSerializer typeSerializer = new IntSerializer(); + CloseableRegistry closeableRegistry = new CloseableRegistry(); + + StreamOperatorStateContext stateContext = streamTaskStateManager.streamOperatorStateContext( + streamOperator, + typeSerializer, + closeableRegistry); + + verify(mockingBackend).createKeyedStateBackend( + any(Environment.class), + any(JobID.class), + any(String.class), + eq(typeSerializer), + anyInt(), + any(KeyGroupRange.class), + any(TaskKvStateRegistry.class)); + + verify(mockingBackend).createOperatorStateBackend( + any(Environment.class), + any(String.class)); + + verify(mockingBackend).createStreamFactory( + any(JobID.class), + any(String.class)); + + OperatorStateBackend operatorStateBackend = stateContext.operatorStateBackend(); + AbstractKeyedStateBackend keyedStateBackend = stateContext.keyedStateBackend(); + InternalTimeServiceManager timeServiceManager = stateContext.internalTimerServiceManager(); + CheckpointStreamFactory streamFactory = stateContext.checkpointStreamFactory(); + CloseableIterable keyedStateInputs = stateContext.rawKeyedStateInputs(); + CloseableIterable operatorStateInputs = stateContext.rawOperatorStateInputs(); + + Assert.assertEquals(true, stateContext.isRestored()); + + Assert.assertNotNull(operatorStateBackend); + Assert.assertNotNull(keyedStateBackend); + // this is deactivated on purpose so that it does not attempt to consume the raw keyed state. + Assert.assertNull(timeServiceManager); + Assert.assertNotNull(streamFactory); + Assert.assertNotNull(keyedStateInputs); + Assert.assertNotNull(operatorStateInputs); + + // check that the expected job manager state was restored + verify(operatorStateBackend).restore(eq(operatorSubtaskState.getManagedOperatorState())); + verify(keyedStateBackend).restore(eq(operatorSubtaskState.getManagedKeyedState())); + + int count = 0; + for (KeyGroupStatePartitionStreamProvider keyedStateInput : keyedStateInputs) { + ++count; + } + Assert.assertEquals(1, count); + + count = 0; + for (StatePartitionStreamProvider operatorStateInput : operatorStateInputs) { + ++count; + } + Assert.assertEquals(3, count); + + checkCloseablesRegistered( + closeableRegistry, + operatorStateBackend, + keyedStateBackend, + streamFactory, + keyedStateInputs, + operatorStateInputs); + } + + private static void checkCloseablesRegistered(CloseableRegistry closeableRegistry, Closeable... closeables) { + for (Closeable closeable : closeables) { + Assert.assertTrue(closeableRegistry.unregisterCloseable(closeable)); + } + } + + private StreamTaskStateManager streamTaskStateManager( + StateBackend stateBackend, + JobManagerTaskRestore jobManagerTaskRestore, + boolean createTimerServiceManager) { + + JobID jobID = new JobID(42L, 43L); + ExecutionAttemptID executionAttemptID = new ExecutionAttemptID(23L, 24L); + TestCheckpointResponder checkpointResponderMock = new TestCheckpointResponder(); + + TaskStateManager taskStateManager = TaskStateManagerImplTest.taskStateManager( + jobID, + executionAttemptID, + checkpointResponderMock, + jobManagerTaskRestore); + + DummyEnvironment dummyEnvironment = new DummyEnvironment("test-task", 1, 0); + dummyEnvironment.setTaskStateManager(taskStateManager); + + ProcessingTimeService processingTimeService = new TestProcessingTimeService(); + + if (createTimerServiceManager) { + return new StreamTaskStateManagerImpl( + dummyEnvironment, + stateBackend, + processingTimeService); + } else { + return new StreamTaskStateManagerImpl( + dummyEnvironment, + stateBackend, + processingTimeService) { + @Override + protected InternalTimeServiceManager internalTimeServiceManager( + AbstractKeyedStateBackend keyedStatedBackend, + KeyContext keyContext, + Iterable rawKeyedStates) throws Exception { + return null; + } + }; + } + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java index 993bffb484c77..8d7f3538f8870 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java @@ -35,8 +35,8 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; -import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.streaming.api.datastream.AsyncDataStream; @@ -53,7 +53,6 @@ import org.apache.flink.streaming.api.operators.async.queue.StreamRecordQueueEntry; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.AcknowledgeStreamMockEnvironment; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback; @@ -503,15 +502,10 @@ public void testStateSnapshotAndRestore() throws Exception { streamConfig.setStreamOperator(operator); streamConfig.setOperatorID(operatorID); - final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( - testHarness.jobConfig, - testHarness.taskConfig, - testHarness.getExecutionConfig(), - testHarness.memorySize, - new MockInputSplitProvider(), - testHarness.bufferSize); + final TestTaskStateManager taskStateManagerMock = testHarness.getTaskStateManager(); + taskStateManagerMock.setWaitForReportLatch(new OneShotLatch()); - testHarness.invoke(env); + testHarness.invoke(); testHarness.waitForTaskRunning(); final long initialTime = 0L; @@ -530,9 +524,9 @@ public void testStateSnapshotAndRestore() throws Exception { task.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpoint()); - env.getCheckpointLatch().await(); + taskStateManagerMock.getWaitForReportLatch().await(); - assertEquals(checkpointId, env.getCheckpointId()); + assertEquals(checkpointId, taskStateManagerMock.getReportedCheckpointId()); LazyAsyncFunction.countDown(); @@ -541,11 +535,12 @@ public void testStateSnapshotAndRestore() throws Exception { // set the operator state from previous attempt into the restored one final OneInputStreamTask restoredTask = new OneInputStreamTask<>(); - TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles(); - restoredTask.setInitialState(subtaskStates); + TaskStateSnapshot subtaskStates = taskStateManagerMock.getLastTaskStateSnapshot(); final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); + + restoredTaskHarness.setTaskStateSnapshot(checkpointId, subtaskStates); restoredTaskHarness.setupOutputForSingletonOperatorChain(); AsyncWaitOperator restoredOperator = new AsyncWaitOperator<>( @@ -978,5 +973,4 @@ public void asyncInvoke(IN input, ResultFuture resultFuture) throws Excepti // no op } } - } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java index baac79cc3a762..bc5d4f011c00e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java @@ -23,7 +23,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineSubsumedException; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -1483,11 +1482,6 @@ long getLastReportedBytesBufferedInAlignment() { return lastReportedBytesBufferedInAlignment; } - @Override - public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { - throw new UnsupportedOperationException("should never be called"); - } - @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { throw new UnsupportedOperationException("should never be called"); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java index 9251ab5ab2f30..da2469b15fe9d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java @@ -22,7 +22,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; @@ -497,11 +496,6 @@ private CheckpointSequenceValidator(long... checkpointIDs) { this.checkpointIDs = checkpointIDs; } - @Override - public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception { - throw new UnsupportedOperationException("should never be called"); - } - @Override public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { throw new UnsupportedOperationException("should never be called"); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java index 8d99acdf64f52..ff9c9ee8ca1ca 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SplitStream; @@ -311,7 +312,12 @@ public void invoke(String value) throws Exception { private > StreamTask createMockTask(StreamConfig streamConfig, String taskName) { final Object checkpointLock = new Object(); - final Environment env = new MockEnvironment(taskName, 3 * 1024 * 1024, new MockInputSplitProvider(), 1024); + final Environment env = new MockEnvironment( + taskName, + 3 * 1024 * 1024, + new MockInputSplitProvider(), + 1024, + new TestTaskStateManager()); @SuppressWarnings("unchecked") StreamTask mockTask = mock(StreamTask.class); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java index c5983cadacb92..8941cc184eb2f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.TaskStateManager; /** * Stream environment that allows to wait for checkpoint acknowledgement. @@ -39,8 +40,16 @@ public AcknowledgeStreamMockEnvironment( ExecutionConfig executionConfig, long memorySize, MockInputSplitProvider inputSplitProvider, - int bufferSize) { - super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize); + int bufferSize, + TaskStateManager taskStateManager) { + super( + jobConfig, + taskConfig, + executionConfig, + memorySize, + inputSplitProvider, + bufferSize, + taskStateManager); } public long getCheckpointId() { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java index 69f693505925d..70cbd79bcca9e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java @@ -22,6 +22,7 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -58,6 +59,7 @@ private void testConfigForwarding(boolean failOnException) throws Exception { final boolean expectedHandlerFlag = failOnException; DummyEnvironment environment = new DummyEnvironment("test", 1, 0); + environment.setTaskStateManager(new TestTaskStateManager()); environment.getExecutionConfig().setFailTaskOnCheckpointError(expectedHandlerFlag); final CheckpointExceptionHandlerFactory inspectingFactory = new CheckpointExceptionHandlerFactory() { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index c641aa80739a0..b6a6deaf115fe 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.blob.PermanentBlobCache; import org.apache.flink.runtime.blob.TransientBlobCache; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -60,6 +61,7 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerActions; @@ -225,6 +227,9 @@ private static Task createTask( streamConfig.setOperatorID(operatorID); TaskStateSnapshot stateSnapshot = new TaskStateSnapshot(); stateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState); + + JobManagerTaskRestore taskRestore = new JobManagerTaskRestore(1L, stateSnapshot); + JobInformation jobInformation = new JobInformation( new JobID(), "test job name", @@ -244,6 +249,13 @@ private static Task createTask( BlobCacheService blobService = new BlobCacheService(mock(PermanentBlobCache.class), mock(TransientBlobCache.class)); + TestTaskStateManager taskStateManager = new TestTaskStateManager(); + taskStateManager.setReportedCheckpointId(taskRestore.getRestoreCheckpointId()); + taskStateManager.setTaskStateSnapshotsByCheckpointId( + Collections.singletonMap( + taskRestore.getRestoreCheckpointId(), + taskRestore.getTaskStateSnapshot())); + return new Task( jobInformation, taskInformation, @@ -254,11 +266,12 @@ private static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - stateSnapshot, + taskRestore, mock(MemoryManager.class), mock(IOManager.class), networkEnvironment, mock(BroadcastVariableManager.class), + taskStateManager, mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index 1fdd922b58d69..c3e12f66d25df 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -27,15 +27,16 @@ import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.graph.StreamNode; @@ -485,19 +486,15 @@ public void testSnapshottingAndRestoring() throws Exception { StreamConfig streamConfig = testHarness.getStreamConfig(); configureChainedTestingStreamOperator(streamConfig, numberChainedTasks); + TestTaskStateManager taskStateManager = testHarness.taskStateManager; + OneShotLatch waitForAcknowledgeLatch = new OneShotLatch(); - AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( - testHarness.jobConfig, - testHarness.taskConfig, - testHarness.executionConfig, - testHarness.memorySize, - new MockInputSplitProvider(), - testHarness.bufferSize); + taskStateManager.setWaitForReportLatch(waitForAcknowledgeLatch); // reset number of restore calls TestingStreamOperator.numberRestoreCalls = 0; - testHarness.invoke(env); + testHarness.invoke(); testHarness.waitForTaskRunning(deadline.timeLeft().toMillis()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, checkpointTimestamp); @@ -507,9 +504,9 @@ public void testSnapshottingAndRestoring() throws Exception { // since no state was set, there shouldn't be restore calls assertEquals(0, TestingStreamOperator.numberRestoreCalls); - env.getCheckpointLatch().await(); + waitForAcknowledgeLatch.await(); - assertEquals(checkpointId, env.getCheckpointId()); + assertEquals(checkpointId, taskStateManager.getReportedCheckpointId()); testHarness.endInput(); testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); @@ -520,17 +517,20 @@ public void testSnapshottingAndRestoring() throws Exception { new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); + restoredTaskHarness.setTaskStateSnapshot(checkpointId, taskStateManager.getLastTaskStateSnapshot()); + StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig(); configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks); - TaskStateSnapshot stateHandles = env.getCheckpointStateHandles(); + TaskStateSnapshot stateHandles = taskStateManager.getLastTaskStateSnapshot(); Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size()); - restoredTask.setInitialState(stateHandles); - TestingStreamOperator.numberRestoreCalls = 0; + // transfer state to new harness + restoredTaskHarness.taskStateManager.restoreLatestCheckpointState( + taskStateManager.getTaskStateSnapshotsByCheckpointId()); restoredTaskHarness.invoke(); restoredTaskHarness.endInput(); restoredTaskHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); @@ -666,28 +666,6 @@ private static class TestingStreamOperator public static int numberRestoreCalls = 0; public static int numberSnapshotCalls = 0; - @Override - public void open() throws Exception { - super.open(); - - ListState partitionableState = getOperatorStateBackend().getListState(TEST_DESCRIPTOR); - - if (numberSnapshotCalls == 0) { - for (Integer v : partitionableState.get()) { - fail(); - } - } else { - Set result = new HashSet<>(); - for (Integer v : partitionableState.get()) { - result.add(v); - } - - assertEquals(2, result.size()); - assertTrue(result.contains(42)); - assertTrue(result.contains(4711)); - } - } - @Override public void snapshotState(StateSnapshotContext context) throws Exception { ListState partitionableState = @@ -705,6 +683,23 @@ public void initializeState(StateInitializationContext context) throws Exception if (context.isRestored()) { ++numberRestoreCalls; } + + ListState partitionableState = context.getOperatorStateStore().getListState(TEST_DESCRIPTOR); + + if (numberSnapshotCalls == 0) { + for (Integer v : partitionableState.get()) { + fail(); + } + } else { + Set result = new HashSet<>(); + for (Integer v : partitionableState.get()) { + result.add(v); + } + + assertEquals(2, result.size()); + assertTrue(result.contains(42)); + assertTrue(result.contains(4711)); + } } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java index d54bba83a8154..94732b5cc521e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java @@ -23,8 +23,10 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; @@ -33,6 +35,7 @@ import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -66,25 +69,27 @@ public void setup() { @Test public void testRestore() throws Exception { + OperatorID headOperatorID = new OperatorID(42L, 42L); OperatorID tailOperatorID = new OperatorID(44L, 44L); - AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), tailOperatorID, new CounterOperator(), Optional.empty()); - assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); - TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); - AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), tailOperatorID, new CounterOperator(), - Optional.of(stateHandles)); + Optional.of(restore)); assertEquals( new HashSet() {{ @@ -96,24 +101,26 @@ public void testRestore() throws Exception { @Test public void testRestoreHeadWithNewId() throws Exception { + OperatorID tailOperatorID = new OperatorID(44L, 44L); - AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( new OperatorID(42L, 42L), new CounterOperator(), tailOperatorID, new CounterOperator(), Optional.empty()); - assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); - TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); - AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + createRunAndCheckpointOperatorChain( new OperatorID(4242L, 4242L), new CounterOperator(), tailOperatorID, new CounterOperator(), - Optional.of(stateHandles)); + Optional.of(restore)); assertEquals( new HashSet() {{ @@ -126,23 +133,22 @@ public void testRestoreHeadWithNewId() throws Exception { public void testRestoreTailWithNewId() throws Exception { OperatorID headOperatorID = new OperatorID(42L, 42L); - AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), new OperatorID(44L, 44L), new CounterOperator(), Optional.empty()); - assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); - TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); - - AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), new OperatorID(4444L, 4444L), new CounterOperator(), - Optional.of(stateHandles)); + Optional.of(restore)); assertEquals( new HashSet() {{ @@ -156,14 +162,16 @@ public void testRestoreAfterScaleUp() throws Exception { OperatorID headOperatorID = new OperatorID(42L, 42L); OperatorID tailOperatorID = new OperatorID(44L, 44L); - AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), tailOperatorID, new CounterOperator(), Optional.empty()); - assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); + + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); // test empty state in case of scale up OperatorSubtaskState emptyHeadOperatorState = StateAssignmentOperation.operatorSubtaskStateFrom( @@ -173,15 +181,14 @@ public void testRestoreAfterScaleUp() throws Exception { Collections.emptyMap(), Collections.emptyMap()); - TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); stateHandles.putSubtaskStateByOperatorID(headOperatorID, emptyHeadOperatorState); - AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + createRunAndCheckpointOperatorChain( headOperatorID, new CounterOperator(), tailOperatorID, new CounterOperator(), - Optional.of(stateHandles)); + Optional.of(restore)); assertEquals( new HashSet() {{ @@ -196,23 +203,22 @@ public void testRestoreWithoutState() throws Exception { OperatorID headOperatorID = new OperatorID(42L, 42L); OperatorID tailOperatorID = new OperatorID(44L, 44L); - AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( headOperatorID, new StatelessOperator(), tailOperatorID, new CounterOperator(), Optional.empty()); - assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); - TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); - - AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + createRunAndCheckpointOperatorChain( headOperatorID, new StatelessOperator(), tailOperatorID, new CounterOperator(), - Optional.of(stateHandles)); + Optional.of(restore)); assertEquals( new HashSet() {{ @@ -222,12 +228,12 @@ public void testRestoreWithoutState() throws Exception { RESTORED_OPERATORS); } - private AcknowledgeStreamMockEnvironment createRunAndCheckpointOperatorChain( - OperatorID headId, - OneInputStreamOperator headOperator, - OperatorID tailId, - OneInputStreamOperator tailOperator, - Optional stateHandles) throws Exception { + private JobManagerTaskRestore createRunAndCheckpointOperatorChain( + OperatorID headId, + OneInputStreamOperator headOperator, + OperatorID tailId, + OneInputStreamOperator tailOperator, + Optional restore) throws Exception { final OneInputStreamTask streamTask = new OneInputStreamTask<>(); final OneInputStreamTaskTestHarness testHarness = @@ -240,40 +246,54 @@ private AcknowledgeStreamMockEnvironment createRunAndCheckpointOperatorChain( .chain(tailId, tailOperator, StringSerializer.INSTANCE) .finish(); - AcknowledgeStreamMockEnvironment environment = new AcknowledgeStreamMockEnvironment( + if (restore.isPresent()) { + JobManagerTaskRestore taskRestore = restore.get(); + testHarness.setTaskStateSnapshot( + taskRestore.getRestoreCheckpointId(), + taskRestore.getTaskStateSnapshot()); + } + + StreamMockEnvironment environment = new StreamMockEnvironment( testHarness.jobConfig, testHarness.taskConfig, testHarness.executionConfig, testHarness.memorySize, new MockInputSplitProvider(), - testHarness.bufferSize); + testHarness.bufferSize, + testHarness.taskStateManager); - if (stateHandles.isPresent()) { - streamTask.setInitialState(stateHandles.get()); - } testHarness.invoke(environment); testHarness.waitForTaskRunning(); processRecords(testHarness); - triggerCheckpoint(testHarness, environment, streamTask); + triggerCheckpoint(testHarness, streamTask); + + TestTaskStateManager taskStateManager = testHarness.taskStateManager; + + JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore( + taskStateManager.getReportedCheckpointId(), + taskStateManager.getLastTaskStateSnapshot()); testHarness.endInput(); testHarness.waitForTaskCompletion(); - - return environment; + return jobManagerTaskRestore; } private void triggerCheckpoint( OneInputStreamTaskTestHarness testHarness, - AcknowledgeStreamMockEnvironment environment, OneInputStreamTask streamTask) throws Exception { + long checkpointId = 1L; CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L); + testHarness.taskStateManager.setWaitForReportLatch(new OneShotLatch()); + while (!streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpoint())) {} - environment.getCheckpointLatch().await(); - assertEquals(checkpointId, environment.getCheckpointId()); + testHarness.taskStateManager.getWaitForReportLatch().await(); + long reportedCheckpointId = testHarness.taskStateManager.getReportedCheckpointId(); + + assertEquals(checkpointId, reportedCheckpointId); } private void processRecords(OneInputStreamTaskTestHarness testHarness) throws Exception { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 231f59e97fb2a..cff64f022a04c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -27,6 +27,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.event.AbstractEvent; @@ -52,8 +53,11 @@ import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.TaskLocalStateStore; +import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; +import org.apache.flink.util.Preconditions; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -93,7 +97,9 @@ public class StreamMockEnvironment implements Environment { private final List outputs; - private final JobID jobID = new JobID(); + private final JobID jobID; + + private final ExecutionAttemptID executionAttemptID; private final BroadcastVariableManager bcVarManager = new BroadcastVariableManager(); @@ -105,23 +111,58 @@ public class StreamMockEnvironment implements Environment { private final ExecutionConfig executionConfig; + private final TaskStateManager taskStateManager; + private volatile boolean wasFailedExternally = false; - public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, ExecutionConfig executionConfig, - long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { + public StreamMockEnvironment( + Configuration jobConfig, + Configuration taskConfig, + ExecutionConfig executionConfig, + long memorySize, + MockInputSplitProvider inputSplitProvider, + int bufferSize, + TaskStateManager taskStateManager) { + this( + new JobID(), + new ExecutionAttemptID(0L, 0L), + jobConfig, + taskConfig, + executionConfig, + memorySize, + inputSplitProvider, + bufferSize, + taskStateManager); + } + + public StreamMockEnvironment( + JobID jobID, + ExecutionAttemptID executionAttemptID, + Configuration jobConfig, + Configuration taskConfig, + ExecutionConfig executionConfig, + long memorySize, + MockInputSplitProvider inputSplitProvider, + int bufferSize, + TaskStateManager taskStateManager) { + + this.jobID = jobID; + this.executionAttemptID = executionAttemptID; + + int subtaskIndex = 0; this.taskInfo = new TaskInfo( "", /* task name */ 1, /* num key groups / max parallelism */ - 0, /* index of this subtask */ + subtaskIndex, /* index of this subtask */ 1, /* num subtasks */ 0 /* attempt number */); this.jobConfiguration = jobConfig; this.taskConfiguration = taskConfig; this.inputs = new LinkedList(); this.outputs = new LinkedList(); - this.memManager = new MemoryManager(memorySize, 1); this.ioManager = new IOManagerAsync(); + this.taskStateManager = Preconditions.checkNotNull(taskStateManager); this.inputSplitProvider = inputSplitProvider; this.bufferSize = bufferSize; @@ -130,11 +171,19 @@ public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, KvStateRegistry registry = new KvStateRegistry(); this.kvStateRegistry = registry.createTaskRegistry(jobID, getJobVertexId()); + + final TaskLocalStateStore localStateStore = new TaskLocalStateStore(jobID, getJobVertexId(), subtaskIndex); } - public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize, - MockInputSplitProvider inputSplitProvider, int bufferSize) { - this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize); + public StreamMockEnvironment( + Configuration jobConfig, + Configuration taskConfig, + long memorySize, + MockInputSplitProvider inputSplitProvider, + int bufferSize, + TaskStateManager taskStateManager) { + + this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize, taskStateManager); } public void addInputGate(InputGate gate) { @@ -310,7 +359,7 @@ public JobVertexID getJobVertexId() { @Override public ExecutionAttemptID getExecutionId() { - return new ExecutionAttemptID(0L, 0L); + return executionAttemptID; } @Override @@ -318,6 +367,11 @@ public BroadcastVariableManager getBroadcastVariableManager() { return this.bcVarManager; } + @Override + public TaskStateManager getTaskStateManager() { + return taskStateManager; + } + @Override public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; @@ -334,6 +388,10 @@ public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpoin @Override public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) { + taskStateManager.reportStateHandles( + new CheckpointMetaData(checkpointId, 0L), + checkpointMetrics, + subtaskState); } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java index 4c73e7254835e..409e7c57e6e8c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java @@ -56,6 +56,7 @@ import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerActions; @@ -155,6 +156,7 @@ public void testConcurrentAsyncCheckpointCannotFailFinishedStreamTask() throws E new IOManagerAsync(), networkEnv, mock(BroadcastVariableManager.class), + new TestTaskStateManager(), mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index b31fb41993686..516af4e470e39 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; @@ -66,11 +67,19 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackendFactory; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskLocalStateStore; +import org.apache.flink.runtime.state.TaskStateManager; +import org.apache.flink.runtime.state.TaskStateManagerImpl; +import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskExecutionState; @@ -83,12 +92,17 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.api.operators.StreamTaskStateManager; +import org.apache.flink.streaming.api.operators.StreamTaskStateManagerImpl; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer; +import org.apache.flink.util.CloseableIterable; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; @@ -141,6 +155,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.powermock.api.mockito.PowerMockito.whenNew; @@ -210,8 +225,10 @@ public void testStateBackendLoadingAndClosing() throws Exception { taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName()); StreamConfig cfg = new StreamConfig(new Configuration()); + cfg.setStateKeySerializer(mock(TypeSerializer.class)); cfg.setOperatorID(new OperatorID(4711L, 42L)); - cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction())); + TestStreamSource streamSource = new TestStreamSource<>(new MockSourceFunction()); + cfg.setStreamOperator(streamSource); cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime); Task task = createTask(StateBackendTestSource.class, cfg, taskManagerConfig); @@ -222,9 +239,14 @@ public void testStateBackendLoadingAndClosing() throws Exception { // wait for clean termination task.getExecutingThread().join(); - // ensure that the state backends are closed - verify(StateBackendTestSource.operatorStateBackend).close(); - verify(StateBackendTestSource.keyedStateBackend).close(); + // ensure that the state backends and stream iterables are closed ... + verify(TestStreamSource.operatorStateBackend).close(); + verify(TestStreamSource.keyedStateBackend).close(); + verify(TestStreamSource.rawOperatorStateInputs).close(); + verify(TestStreamSource.rawKeyedStateInputs).close(); + // ... and disposed + verify(TestStreamSource.operatorStateBackend).dispose(); + verify(TestStreamSource.keyedStateBackend).dispose(); assertEquals(ExecutionState.FINISHED, task.getExecutionState()); } @@ -248,8 +270,14 @@ public void testStateBackendClosingOnFailure() throws Exception { task.getExecutingThread().join(); // ensure that the state backends are closed - verify(StateBackendTestSource.operatorStateBackend).close(); - verify(StateBackendTestSource.keyedStateBackend).close(); + // ensure that the state backends and stream iterables are closed ... + verify(TestStreamSource.operatorStateBackend).close(); + verify(TestStreamSource.keyedStateBackend).close(); + verify(TestStreamSource.rawOperatorStateInputs).close(); + verify(TestStreamSource.rawKeyedStateInputs).close(); + // ... and disposed + verify(TestStreamSource.operatorStateBackend).dispose(); + verify(TestStreamSource.keyedStateBackend).dispose(); assertEquals(ExecutionState.FAILED, task.getExecutionState()); } @@ -461,6 +489,8 @@ public void testAsyncCheckpointingConcurrentCloseAfterAcknowledge() throws Excep when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0); Environment mockEnvironment = mock(Environment.class); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); + + CheckpointResponder checkpointResponder = mock(CheckpointResponder.class); doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { @@ -471,7 +501,21 @@ public Object answer(InvocationOnMock invocation) throws Throwable { return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); + }).when(checkpointResponder).acknowledgeCheckpoint( + any(JobID.class), + any(ExecutionAttemptID.class), + anyLong(), + any(CheckpointMetrics.class), + any(TaskStateSnapshot.class)); + + TaskStateManager taskStateManager = new TaskStateManagerImpl( + new JobID(1L, 2L), + new ExecutionAttemptID(1L, 2L), + mock(TaskLocalStateStore.class), + null, + checkpointResponder); + + when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager); StreamTask> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); @@ -509,6 +553,11 @@ public Object answer(InvocationOnMock invocation) throws Throwable { AbstractStateBackend mockStateBackend = mock(AbstractStateBackend.class); when(mockStateBackend.createStreamFactory(any(JobID.class), anyString())).thenReturn(mockStreamFactory); + StreamTaskStateManager streamTaskStateManager = new StreamTaskStateManagerImpl( + mockEnvironment, + mockStateBackend, + mock(ProcessingTimeService.class)); + Whitebox.setInternalState(streamTask, "isRunning", true); Whitebox.setInternalState(streamTask, "lock", new Object()); Whitebox.setInternalState(streamTask, "operatorChain", operatorChain); @@ -516,6 +565,7 @@ public Object answer(InvocationOnMock invocation) throws Throwable { Whitebox.setInternalState(streamTask, "asyncOperationsThreadPool", Executors.newFixedThreadPool(1)); Whitebox.setInternalState(streamTask, "configuration", new StreamConfig(new Configuration())); Whitebox.setInternalState(streamTask, "stateBackend", mockStateBackend); + Whitebox.setInternalState(streamTask, "streamTaskStateManager", streamTaskStateManager); streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpoint()); @@ -524,7 +574,12 @@ public Object answer(InvocationOnMock invocation) throws Throwable { ArgumentCaptor subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class); // check that the checkpoint has been completed - verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture()); + verify(checkpointResponder).acknowledgeCheckpoint( + any(JobID.class), + any(ExecutionAttemptID.class), + eq(checkpointId), + any(CheckpointMetrics.class), + subtaskStateCaptor.capture()); TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue(); OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue(); @@ -692,17 +747,30 @@ public void testEmptySubtaskStateLeadsToStatelessAcknowledgment() throws Excepti final OneShotLatch checkpointCompletedLatch = new OneShotLatch(); final List checkpointResult = new ArrayList<>(1); - // we remember what is acknowledged (expected to be null as our task will snapshot empty states). + CheckpointResponder checkpointResponder = mock(CheckpointResponder.class); doAnswer(new Answer() { @Override - public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - SubtaskState subtaskState = invocationOnMock.getArgumentAt(2, SubtaskState.class); + public Object answer(InvocationOnMock invocation) throws Throwable { + SubtaskState subtaskState = invocation.getArgumentAt(4, SubtaskState.class); checkpointResult.add(subtaskState); checkpointCompletedLatch.trigger(); return null; } - }).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class)); - + }).when(checkpointResponder).acknowledgeCheckpoint( + any(JobID.class), + any(ExecutionAttemptID.class), + anyLong(), + any(CheckpointMetrics.class), + any(TaskStateSnapshot.class)); + + TaskStateManager taskStateManager = new TaskStateManagerImpl( + new JobID(1L, 2L), + new ExecutionAttemptID(1L, 2L), + mock(TaskLocalStateStore.class), + null, + checkpointResponder); + + when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager); when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo); StreamTask> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS); @@ -760,7 +828,8 @@ public void testOperatorClosingBeforeStopRunning() throws Throwable { new MockInputSplitProvider(), 1, taskConfiguration, - new ExecutionConfig()); + new ExecutionConfig(), + new TestTaskStateManager()); StreamTask streamTask = new NoOpStreamTask<>(mockEnvironment); final AtomicReference atomicThrowable = new AtomicReference<>(null); @@ -920,11 +989,12 @@ public static Task createTask( Collections.emptyList(), Collections.emptyList(), 0, - new TaskStateSnapshot(), + new JobManagerTaskRestore(1L, null), mock(MemoryManager.class), mock(IOManager.class), network, mock(BroadcastVariableManager.class), + new TestTaskStateManager(), mock(TaskManagerActions.class), mock(InputSplitProvider.class), mock(CheckpointResponder.class), @@ -999,40 +1069,17 @@ public static final class MockStateBackend implements StateBackendFactory() { - @Override - public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - return Mockito.mock(OperatorStateBackend.class); - } - }); - - Mockito.when(stateBackendMock.createKeyedStateBackend( - Mockito.any(Environment.class), - Mockito.any(JobID.class), - Mockito.any(String.class), - Mockito.any(TypeSerializer.class), - Mockito.any(int.class), - Mockito.any(KeyGroupRange.class), - Mockito.any(TaskKvStateRegistry.class))) - .thenAnswer(new Answer() { - @Override - public AbstractKeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - return Mockito.mock(AbstractKeyedStateBackend.class); - } - }); - } - catch (Exception e) { - // this is needed, because the signatures of the mocked methods throw 'Exception' - throw new RuntimeException(e); - } + return new MemoryStateBackend() { + @Override + public OperatorStateBackend createOperatorStateBackend(Environment env, String operatorIdentifier) throws Exception { + return spy(super.createOperatorStateBackend(env, operatorIdentifier)); + } - return stateBackendMock; + @Override + public AbstractKeyedStateBackend createKeyedStateBackend(Environment env, JobID jobID, String operatorIdentifier, TypeSerializer keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, TaskKvStateRegistry kvStateRegistry) { + return spy(super.createKeyedStateBackend(env, jobID, operatorIdentifier, keySerializer, numberOfKeyGroups, keyGroupRange, kvStateRegistry)); + } + }; } } @@ -1048,18 +1095,9 @@ public static class StateBackendTestSource extends StreamTask { + + final StreamOperatorStateContext context = + streamTaskStateManager.streamOperatorStateContext(operator, keySerializer, closeableRegistry); + + return new StreamOperatorStateContext() { + @Override + public boolean isRestored() { + return context.isRestored(); + } + + @Override + public OperatorStateBackend operatorStateBackend() { + return context.operatorStateBackend(); + } + + @Override + public AbstractKeyedStateBackend keyedStateBackend() { + return context.keyedStateBackend(); + } + + @Override + public InternalTimeServiceManager internalTimerServiceManager() { + return spy(context.internalTimerServiceManager()); + } + + @Override + public CheckpointStreamFactory checkpointStreamFactory() { + return replaceWithSpy(context.checkpointStreamFactory()); + } + + @Override + public CloseableIterable rawOperatorStateInputs() { + return replaceWithSpy(context.rawOperatorStateInputs()); + } + + @Override + public CloseableIterable rawKeyedStateInputs() { + return replaceWithSpy(context.rawKeyedStateInputs()); + } + + public T replaceWithSpy(T closeable) { + T spyCloseable = spy(closeable); + if (closeableRegistry.unregisterCloseable(closeable)) { + try { + closeableRegistry.registerCloseable(spyCloseable); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return spyCloseable; + } + }; + }; + } } /** @@ -1212,4 +1308,27 @@ public void close() { interrupt(); } } + + static class TestStreamSource> extends StreamSource { + + static AbstractKeyedStateBackend keyedStateBackend; + static OperatorStateBackend operatorStateBackend; + static CloseableIterable rawOperatorStateInputs; + static CloseableIterable rawKeyedStateInputs; + + public TestStreamSource(SRC sourceFunction) { + super(sourceFunction); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + keyedStateBackend = (AbstractKeyedStateBackend) getKeyedStateBackend(); + operatorStateBackend = getOperatorStateBackend(); + rawOperatorStateInputs = + (CloseableIterable) context.getRawOperatorStateInputs(); + rawKeyedStateInputs = + (CloseableIterable) context.getRawKeyedStateInputs(); + super.initializeState(context); + } + } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index 5b154770cfe80..b5f1c79a4b94e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -22,12 +22,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -77,6 +79,8 @@ public class StreamTaskTestHarness { public Configuration taskConfig; protected StreamConfig streamConfig; + protected TestTaskStateManager taskStateManager; + private AbstractInvokable task; private TypeSerializer outputSerializer; @@ -108,6 +112,8 @@ public StreamTaskTestHarness(AbstractInvokable task, TypeInformation output outputSerializer = outputType.createSerializer(executionConfig); outputStreamRecordSerializer = new StreamElementSerializer(outputSerializer); + + this.taskStateManager = new TestTaskStateManager(); } public ProcessingTimeService getProcessingTimeService() { @@ -122,6 +128,16 @@ public ProcessingTimeService getProcessingTimeService() { */ protected void initializeInputs() throws IOException, InterruptedException {} + public TestTaskStateManager getTaskStateManager() { + return taskStateManager; + } + + public void setTaskStateSnapshot(long checkpointId, TaskStateSnapshot taskStateSnapshot) { + taskStateManager.setReportedCheckpointId(checkpointId); + taskStateManager.setTaskStateSnapshotsByCheckpointId( + Collections.singletonMap(checkpointId, taskStateSnapshot)); + } + @SuppressWarnings("unchecked") private void initializeOutput() { outputList = new LinkedBlockingQueue(); @@ -162,7 +178,13 @@ public void setupOutputForSingletonOperatorChain() { public StreamMockEnvironment createEnvironment() { return new StreamMockEnvironment( - jobConfig, taskConfig, executionConfig, memorySize, new MockInputSplitProvider(), bufferSize); + jobConfig, + taskConfig, + executionConfig, + memorySize, + new MockInputSplitProvider(), + bufferSize, + taskStateManager); } /** diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java index d755c566efe62..84b5fb563514d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java @@ -62,6 +62,7 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; @@ -231,6 +232,7 @@ private static Task createTask( mock(IOManager.class), network, mock(BroadcastVariableManager.class), + new TestTaskStateManager(), mock(TaskManagerActions.class), mock(InputSplitProvider.class), checkpointResponder, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java index 13f0b3fa59ba2..02ab609bc5488 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner; import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockEnvironment; @@ -37,9 +38,9 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -48,6 +49,8 @@ import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamTaskStateManager; +import org.apache.flink.streaming.api.operators.StreamTaskStateManagerImpl; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -61,6 +64,7 @@ import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; +import org.mockito.internal.util.MockUtil; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -76,6 +80,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; /** @@ -97,6 +102,10 @@ public class AbstractStreamOperatorTestHarness implements AutoCloseable { protected final StreamTask mockTask; + protected final StreamTaskStateManager streamTaskStateManager; + + protected final TestTaskStateManager taskStateManager; + final Environment environment; CloseableRegistry closableRegistry; @@ -132,6 +141,7 @@ public AbstractStreamOperatorTestHarness( 1024, new Configuration(), new ExecutionConfig(), + new TestTaskStateManager(), maxParallelism, parallelism, subtaskIndex)); @@ -139,25 +149,41 @@ public AbstractStreamOperatorTestHarness( public AbstractStreamOperatorTestHarness( StreamOperator operator, - final Environment environment) throws Exception { + Environment env) throws Exception { this.operator = operator; this.outputList = new ConcurrentLinkedQueue<>(); this.sideOutputLists = new HashMap<>(); - Configuration underlyingConfig = environment.getTaskConfiguration(); + Configuration underlyingConfig = env.getTaskConfiguration(); this.config = new StreamConfig(underlyingConfig); this.config.setCheckpointingEnabled(true); this.config.setOperatorID(new OperatorID()); - this.executionConfig = environment.getExecutionConfig(); + this.executionConfig = env.getExecutionConfig(); this.closableRegistry = new CloseableRegistry(); this.checkpointLock = new Object(); - this.environment = Preconditions.checkNotNull(environment); + Preconditions.checkNotNull(env); + + MockUtil mockUtil = new MockUtil(); + + if (!mockUtil.isMock(env) && !mockUtil.isSpy(env)) { + env = spy(env); + } + + this.environment = env; + + this.taskStateManager = new TestTaskStateManager( + env.getJobID(), + env.getExecutionId()); + + when(this.environment.getTaskStateManager()).thenReturn(this.taskStateManager); mockTask = mock(StreamTask.class); processingTimeService = new TestProcessingTimeService(); processingTimeService.setCurrentTime(0); + this.streamTaskStateManager = createStreamTaskStateManager(environment, stateBackend, processingTimeService); + StreamStatusMaintainer mockStreamStatusMaintainer = new StreamStatusMaintainer() { StreamStatus currentStreamStatus = StreamStatus.ACTIVE; @@ -180,8 +206,9 @@ public StreamStatus getStreamStatus() { when(mockTask.getTaskConfiguration()).thenReturn(underlyingConfig); when(mockTask.getEnvironment()).thenReturn(environment); when(mockTask.getExecutionConfig()).thenReturn(executionConfig); + when(mockTask.getStreamTaskStateManager()).thenReturn(streamTaskStateManager); - ClassLoader cl = environment.getUserClassLoader(); + ClassLoader cl = env.getUserClassLoader(); when(mockTask.getUserCodeClassLoader()).thenReturn(cl); when(mockTask.getCancelables()).thenReturn(this.closableRegistry); @@ -208,31 +235,6 @@ public CheckpointStreamFactory answer(InvocationOnMock invocationOnMock) throws throw new RuntimeException(e.getMessage(), e); } - try { - doAnswer(new Answer() { - @Override - public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - final StreamOperator operator = (StreamOperator) invocationOnMock.getArguments()[0]; - final Collection stateHandles = (Collection) invocationOnMock.getArguments()[1]; - OperatorStateBackend osb; - - osb = stateBackend.createOperatorStateBackend( - environment, - operator.getClass().getSimpleName()); - - mockTask.getCancelables().registerCloseable(osb); - - if (null != stateHandles) { - osb.restore(stateHandles); - } - - return osb; - } - }).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class)); - } catch (Exception e) { - throw new RuntimeException(e.getMessage(), e); - } - doAnswer(new Answer() { @Override public ProcessingTimeService answer(InvocationOnMock invocation) throws Throwable { @@ -242,6 +244,16 @@ public ProcessingTimeService answer(InvocationOnMock invocation) throws Throwabl } + protected StreamTaskStateManager createStreamTaskStateManager( + Environment env, + StateBackend stateBackend, + ProcessingTimeService processingTimeService) { + return new StreamTaskStateManagerImpl( + env, + stateBackend, + processingTimeService); + } + public void setStateBackend(StateBackend stateBackend) { this.stateBackend = stateBackend; } @@ -300,7 +312,7 @@ public void setup(TypeSerializer outputSerializer) { } /** - * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}. + * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState()}. * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} * if it was not called before. * @@ -357,16 +369,20 @@ public void initializeState(OperatorStateHandles operatorStateHandles) throws Ex rawOperatorState, numSubtasks).get(subtaskIndex); - OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState( + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( nullToEmptyCollection(localManagedOperatorState), nullToEmptyCollection(localRawOperatorState), nullToEmptyCollection(localManagedKeyGroupState), nullToEmptyCollection(localRawKeyGroupState)); - operator.initializeState(massagedOperatorStateHandles); - } else { - operator.initializeState(null); + TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(); + taskStateSnapshot.putSubtaskStateByOperatorID(operator.getOperatorID(), operatorSubtaskState); + + taskStateManager.setReportedCheckpointId(0); + taskStateManager.setTaskStateSnapshotsByCheckpointId(Collections.singletonMap(0L, taskStateSnapshot)); } + + operator.initializeState(); initializeCalled = true; } @@ -476,10 +492,10 @@ public OperatorStateHandles snapshot(long checkpointId, long timestamp) throws E } /** - * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyOfCompletedCheckpoint(long)} ()}. + * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyCheckpointComplete(long)} ()}. */ public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - operator.notifyOfCompletedCheckpoint(checkpointId); + operator.notifyCheckpointComplete(checkpointId); } /** diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java index c2ec63a6474d0..ea0755725db75 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java @@ -18,32 +18,14 @@ package org.apache.flink.streaming.util; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.state.AbstractKeyedStateBackend; -import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; -import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; -import org.apache.flink.util.Migration; - -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.doAnswer; /** * Extension of {@link OneInputStreamOperatorTestHarness} that allows the operator to get @@ -52,14 +34,6 @@ public class KeyedOneInputStreamOperatorTestHarness extends OneInputStreamOperatorTestHarness { - // in case the operator creates one we store it here so that we - // can snapshot its state - private AbstractKeyedStateBackend keyedStateBackend = null; - - // when we restore we keep the state here so that we can call restore - // when the operator requests the keyed state backend - private List restoredKeyedState = null; - public KeyedOneInputStreamOperatorTestHarness( OneInputStreamOperator operator, final KeySelector keySelector, @@ -72,8 +46,6 @@ public KeyedOneInputStreamOperatorTestHarness( ClosureCleaner.clean(keySelector, false); config.setStatePartitioner(0, keySelector); config.setStateKeySerializer(keyType.createSerializer(executionConfig)); - - setupMockTaskCreateKeyedBackend(); } public KeyedOneInputStreamOperatorTestHarness( @@ -94,54 +66,11 @@ public KeyedOneInputStreamOperatorTestHarness( ClosureCleaner.clean(keySelector, false); config.setStatePartitioner(0, keySelector); config.setStateKeySerializer(keyType.createSerializer(executionConfig)); - - setupMockTaskCreateKeyedBackend(); - } - - private void setupMockTaskCreateKeyedBackend() { - - try { - doAnswer(new Answer() { - @Override - public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - - final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0]; - final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1]; - final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2]; - - if (keyedStateBackend != null) { - keyedStateBackend.dispose(); - } - - keyedStateBackend = stateBackend.createKeyedStateBackend( - mockTask.getEnvironment(), - new JobID(), - "test_op", - keySerializer, - numberOfKeyGroups, - keyGroupRange, - mockTask.getEnvironment().getTaskKvStateRegistry()); - - keyedStateBackend.restore(restoredKeyedState); - - return keyedStateBackend; - } - }).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class)); - } catch (Exception e) { - throw new RuntimeException(e.getMessage(), e); - } - } - - private static boolean hasMigrationHandles(Collection allKeyGroupsHandles) { - for (KeyedStateHandle handle : allKeyGroupsHandles) { - if (handle instanceof Migration) { - return true; - } - } - return false; } public int numKeyedStateEntries() { + AbstractStreamOperator abstractStreamOperator = (AbstractStreamOperator) operator; + KeyedStateBackend keyedStateBackend = abstractStreamOperator.getKeyedStateBackend(); if (keyedStateBackend instanceof HeapKeyedStateBackend) { return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(); } else { @@ -150,47 +79,12 @@ public int numKeyedStateEntries() { } public int numKeyedStateEntries(N namespace) { + AbstractStreamOperator abstractStreamOperator = (AbstractStreamOperator) operator; + KeyedStateBackend keyedStateBackend = abstractStreamOperator.getKeyedStateBackend(); if (keyedStateBackend instanceof HeapKeyedStateBackend) { return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(namespace); } else { throw new UnsupportedOperationException(); } } - - @Override - public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception { - if (operatorStateHandles != null) { - int numKeyGroups = getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(); - int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(); - int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask(); - - // create a new OperatorStateHandles that only contains the state for our key-groups - - List keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions( - numKeyGroups, - numSubtasks); - - KeyGroupRange localKeyGroupRange = - keyGroupPartitions.get(subtaskIndex); - - restoredKeyedState = null; - Collection managedKeyedState = operatorStateHandles.getManagedKeyedState(); - if (managedKeyedState != null) { - - // if we have migration handles, don't reshuffle state and preserve - // the migration tag - if (hasMigrationHandles(managedKeyedState)) { - List result = new ArrayList<>(managedKeyedState.size()); - result.addAll(managedKeyedState); - restoredKeyedState = result; - } else { - restoredKeyedState = StateAssignmentOperation.getKeyedStateHandles( - managedKeyedState, - localKeyGroupRange); - } - } - } - - super.initializeState(operatorStateHandles); - } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java index b0500ca9d9e09..607eee045bdcb 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java @@ -18,27 +18,13 @@ package org.apache.flink.streaming.util; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.runtime.state.AbstractKeyedStateBackend; -import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; -import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; - -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import java.util.Collection; - -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.doAnswer; /** * Extension of {@link TwoInputStreamOperatorTestHarness} that allows the operator to get @@ -47,14 +33,6 @@ public class KeyedTwoInputStreamOperatorTestHarness extends TwoInputStreamOperatorTestHarness { - // in case the operator creates one we store it here so that we - // can snapshot its state - private AbstractKeyedStateBackend keyedStateBackend = null; - - // when we restore we keep the state here so that we can call restore - // when the operator requests the keyed state backend - private Collection restoredKeyedState = null; - public KeyedTwoInputStreamOperatorTestHarness( TwoInputStreamOperator operator, KeySelector keySelector1, @@ -70,8 +48,6 @@ public KeyedTwoInputStreamOperatorTestHarness( config.setStatePartitioner(0, keySelector1); config.setStatePartitioner(1, keySelector2); config.setStateKeySerializer(keyType.createSerializer(executionConfig)); - - setupMockTaskCreateKeyedBackend(); } public KeyedTwoInputStreamOperatorTestHarness( @@ -82,50 +58,9 @@ public KeyedTwoInputStreamOperatorTestHarness( this(operator, keySelector1, keySelector2, keyType, 1, 1, 0); } - private void setupMockTaskCreateKeyedBackend() { - - try { - doAnswer(new Answer() { - @Override - public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - - final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0]; - final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1]; - final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2]; - - if (keyedStateBackend != null) { - keyedStateBackend.close(); - } - - keyedStateBackend = stateBackend.createKeyedStateBackend( - mockTask.getEnvironment(), - new JobID(), - "test_op", - keySerializer, - numberOfKeyGroups, - keyGroupRange, - mockTask.getEnvironment().getTaskKvStateRegistry()); - if (restoredKeyedState != null) { - keyedStateBackend.restore(restoredKeyedState); - } - return keyedStateBackend; - } - }).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class)); - } catch (Exception e) { - throw new RuntimeException(e.getMessage(), e); - } - } - - @Override - public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception { - if (restoredKeyedState != null) { - restoredKeyedState = operatorStateHandles.getManagedKeyedState(); - } - - super.initializeState(operatorStateHandles); - } - public int numKeyedStateEntries() { + AbstractStreamOperator abstractStreamOperator = (AbstractStreamOperator) operator; + KeyedStateBackend keyedStateBackend = abstractStreamOperator.getKeyedStateBackend(); if (keyedStateBackend instanceof HeapKeyedStateBackend) { return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(); } else { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java index 5f17467f72f98..f1cc07fe04ad4 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; @@ -49,10 +50,15 @@ public static List runSourceFunction(SourceFunction< AbstractStreamOperator operator = mock(AbstractStreamOperator.class); when(operator.getExecutionConfig()).thenReturn(new ExecutionConfig()); - RuntimeContext runtimeContext = new StreamingRuntimeContext( - operator, - new MockEnvironment("MockTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024), - new HashMap>()); + RuntimeContext runtimeContext = new StreamingRuntimeContext( + operator, + new MockEnvironment( + "MockTask", + 3 * 1024 * 1024, + new MockInputSplitProvider(), + 1024, + new TestTaskStateManager()), + new HashMap>()); ((RichFunction) sourceFunction).setRuntimeContext(runtimeContext); diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java index f5c769d860d95..927597f90630e 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java @@ -40,7 +40,6 @@ import java.io.IOException; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** @@ -85,8 +84,24 @@ public String map(Tuple2 value) throws Exception { fail(); } catch (JobExecutionException e) { + boolean success = false; Throwable t = e.getCause(); - assertTrue("wrong exception", t instanceof SuccessException); + while (t != null) { + if (t instanceof SuccessException) { + success = true; + break; + } else { + if (t != t.getCause()) { + t = t.getCause(); + } else { + break; + } + } + } + + if (!success) { + fail(); + } } } @@ -102,7 +117,7 @@ public CheckpointStreamFactory createStreamFactory(JobID jobId, @Override public CheckpointStreamFactory createSavepointStreamFactory(JobID jobId, String operatorIdentifier, String targetLocation) throws IOException { - throw new UnsupportedOperationException(); + throw new SuccessException(); } @Override @@ -122,7 +137,7 @@ public OperatorStateBackend createOperatorStateBackend( Environment env, String operatorIdentifier) throws Exception { - throw new UnsupportedOperationException(); + throw new SuccessException(); } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java b/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java index 54fe8792c5cdf..623121596fe79 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java +++ b/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java @@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamMap; @@ -360,7 +361,8 @@ private OperatorStateHandles runOperator( 16, 1, 0, - classLoader); + classLoader, + new TestTaskStateManager()); OneInputStreamOperatorTestHarness harness;