diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 7ab35c4fd49bc..f332d1efd8418 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -65,6 +65,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
@@ -185,7 +186,7 @@ public RocksDBKeyedStateBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoreState
+			Collection<KeyGroupsStateHandle> restoreState
 	) throws Exception {

 		this(jobId,
@@ -603,7 +604,7 @@ public RocksDBRestoreOperation(RocksDBKeyedStateBackend rocksDBKeyedStateBack
		 * @throws ClassNotFoundException
		 * @throws RocksDBException
		 */
-		public void doRestore(List<KeyGroupsStateHandle> keyGroupsStateHandles)
+		public void doRestore(Collection<KeyGroupsStateHandle> keyGroupsStateHandles)
				throws IOException, ClassNotFoundException, RocksDBException {

			for (KeyGroupsStateHandle keyGroupsStateHandle : keyGroupsStateHandles) {
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index a0c980b24ce54..82e7899b035f3 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -40,6 +40,7 @@
 import java.io.ObjectOutputStream;
 import java.net.URI;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.UUID;
@@ -258,7 +259,7 @@ public AbstractKeyedStateBackend restoreKeyedStateBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {

 		lazyInitializeForJob(env, operatorIdentifier);
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 8f58075132916..4d1ab50b60eb8 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -28,9 +28,9 @@
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -70,7 +70,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;

-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;

 /**
  * Tests for asynchronous RocksDB Key/Value state checkpoints.
@@ -136,7 +136,7 @@ public String getKey(String value) throws Exception {
 				@Override
 				public void acknowledgeCheckpoint(
 						CheckpointMetaData checkpointMetaData,
-						CheckpointStateHandles checkpointStateHandles) {
+						SubtaskState checkpointStateHandles) {

 					super.acknowledgeCheckpoint(checkpointMetaData);

@@ -148,8 +148,8 @@ public void acknowledgeCheckpoint(
 						e.printStackTrace();
 					}

-					// should be only one k/v state
-					assertEquals(1, checkpointStateHandles.getKeyGroupsStateHandle().size());
+					// should be one k/v state
+					assertNotNull(checkpointStateHandles.getManagedKeyedState());

 					// we now know that the checkpoint went through
 					ensureCheckpointLatch.trigger();
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
index ce513cbc1ed85..b120fa5de027b 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
@@ -210,7 +210,7 @@ public interface RuntimeContext {
 	 * @return The distributed cache of the worker executing this instance.
 	 */
 	DistributedCache getDistributedCache();
-	
+
 	// ------------------------------------------------------------------------
 	//  Methods for accessing state
 	// ------------------------------------------------------------------------
@@ -266,7 +266,7 @@ public interface RuntimeContext {
 	 * Gets a handle to the system's key/value list state. This state is similar to the state
 	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
 	 * holds lists. One can add elements to the list, or retrieve the list as a whole.
-	 * 
+	 *
 	 * <p>This state is only accessible if the function is executed on a KeyedStream.
 	 *
 	 * <pre>{@code
@@ -331,7 +331,7 @@ public interface RuntimeContext {
 	 *         return new Tuple2<>(value, sum.get());
 	 *     }
 	 * });
-	 * 
+	 *
 	 * }</pre>
 	 *
 	 * @param stateProperties The descriptor defining the properties of the state.
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
new file mode 100644
index 0000000000000..89c12404faf0b
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface contains methods for registering keyed state with a managed store.
+ */
+@PublicEvolving
+public interface KeyedStateStore {
+
+	/**
+	 * Gets a handle to the system's key/value state. The key/value state is only accessible
+	 * if the function is executed on a KeyedStream. On each access, the state exposes the value
+	 * for the key of the element currently processed by the function.
+	 * Each function may have multiple partitioned states, addressed with different names.
+	 *
+	 * <p>Because the scope of each value is the key of the currently processed element,
+	 * and the elements are distributed by the Flink runtime, the system can transparently
+	 * scale out and redistribute the state and KeyedStream.
+	 *
+	 * <p>The following code example shows how to implement a continuous counter that counts
+	 * how many times elements of a certain key occur, and emits an updated count for that
+	 * element on each occurrence.
+	 *
+	 * <pre>{@code
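+	 * // keeps one running count per key and emits the element together
+	 * // with its updated count for every input record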
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ValueState<Long> count;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         count = getRuntimeContext().getState(
+	 *                 new ValueStateDescriptor<>("count", LongSerializer.INSTANCE, 0L));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         long newCount = count.value() + 1;
+	 *         count.update(newCount);
+	 *         return new Tuple2<>(value, newCount);
+	 *     }
+	 * });
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 * function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value list state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * holds lists. One can add elements to the list, or retrieve the list as a whole.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
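+	 * // buffers elements per key in list state; a divider element
+	 * // flushes everything buffered for its key downstream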
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.flatMap(new RichFlatMapFunction<MyType, MyType>() {
+	 *
+	 *     private ListState<MyType> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getListState(
+	 *                 new ListStateDescriptor<>("myState", MyType.class));
+	 *     }
+	 *
+	 *     public void flatMap(MyType value, Collector<MyType> out) {
+	 *         if (value.isDivider()) {
+	 *             for (MyType t : state.get()) {
+	 *                 out.collect(t);
+	 *             }
+	 *         } else {
+	 *             state.add(value);
+	 *         }
+	 *     }
+	 * });
+	 * }</pre>
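+	 *
+	 * <p>As with all keyed state, the list above is scoped to the key of the element
+	 * currently being processed, so each key accumulates its own independent list.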
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 * function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value reducing state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * aggregates values.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
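+	 * // maintains one running sum per key; each element's count is
+	 * // folded into the sum by the supplied ReduceFunction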
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ReducingState<Long> sum;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         sum = getRuntimeContext().getReducingState(
+	 *                 new ReducingStateDescriptor<>("sum", (a, b) -> a + b, Long.class));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         sum.add(value.count());
+	 *         return new Tuple2<>(value, sum.get());
+	 *     }
+	 * });
+	 *
+	 * }</pre>
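+	 *
+	 * <p>The {@code (a, b) -> a + b} lambda above is shorthand for a {@code ReduceFunction};
+	 * written out as an anonymous class (names here are illustrative only), it would look like:
+	 * <pre>{@code
+	 * ReduceFunction<Long> addLongs = new ReduceFunction<Long>() {
+	 *     public Long reduce(Long a, Long b) {
+	 *         return a + b;
+	 *     }
+	 * };
+	 * }</pre>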
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 * function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);
+}
\ No newline at end of file
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 03c11f695bdee..43dbe51e4fb6a 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -18,16 +18,17 @@
 package org.apache.flink.api.common.state;

+import org.apache.flink.annotation.PublicEvolving;
+
 import java.io.Serializable;
 import java.util.Set;

 /**
- * Interface for a backend that manages operator state.
+ * This interface contains methods for registering operator state with a managed store.
  */
+@PublicEvolving
 public interface OperatorStateStore {

-	String DEFAULT_OPERATOR_STATE_NAME = "_default_";
-
 	/**
 	 * Creates a state descriptor of the given name that uses Java serialization to persist the
 	 * state.
@@ -39,7 +40,7 @@ public interface OperatorStateStore {
 	 * @return A list state using Java serialization to serialize state objects.
 	 * @throws Exception
 	 */
-	ListState<Serializable> getSerializableListState(String stateName) throws Exception;
+	<T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception;

 	/**
 	 * Creates (or restores) a list state. Each state is registered under a unique name.
diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
index e7b2828bf40b2..172da796f4f43 100644
--- a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
+++ b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
@@ -18,14 +18,14 @@
 package org.apache.flink.core.fs.local;

-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.FSDataInputStream;

 import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.channels.FileChannel;

 /**
  * The LocalDataInputStream class is a wrapper class for a data
@@ -36,6 +36,7 @@ public class LocalDataInputStream extends FSDataInputStream {

 	/** The file input stream used to read data from.*/
 	private final FileInputStream fis;
+	private final FileChannel fileChannel;

 	/**
 	 * Constructs a new LocalDataInputStream object from a given {@link File} object.
@@ -46,16 +47,19 @@ public class LocalDataInputStream extends FSDataInputStream {
 	 */
 	public LocalDataInputStream(File file) throws IOException {
 		this.fis = new FileInputStream(file);
+		this.fileChannel = fis.getChannel();
 	}

 	@Override
 	public void seek(long desired) throws IOException {
-		this.fis.getChannel().position(desired);
+		if (desired != getPos()) {
+			this.fileChannel.position(desired);
+		}
 	}

 	@Override
 	public long getPos() throws IOException {
-		return this.fis.getChannel().position();
+		return this.fileChannel.position();
 	}

 	@Override
@@ -70,6 +74,7 @@ public int read(@Nonnull byte[] buffer, int offset, int length) throws IOExcepti

 	@Override
 	public void close() throws IOException {
+		// According to the javadoc, this also closes the channel
 		this.fis.close();
 	}

diff --git a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
new file mode 100644
index 0000000000000..15d00aea71f41
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import java.util.Collection;
+import java.util.Map;
+
+public final class CollectionUtil {
+
+	private CollectionUtil() {
+		throw new AssertionError();
+	}
+
+	public static boolean isNullOrEmpty(Collection<?> collection) {
+		return collection == null || collection.isEmpty();
+	}
+
+	public static boolean isNullOrEmpty(Map<?, ?> map) {
+		return map == null || map.isEmpty();
+	}
+}
\ No newline at end of file
diff --git a/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
new file mode 100644
index 0000000000000..62d836bb62b0d
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.util; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.RunnableFuture; + +public class FutureUtil { + + private FutureUtil() { + throw new AssertionError(); + } + + public static T runIfNotDoneAndGet(RunnableFuture future) throws ExecutionException, InterruptedException { + + if (null == future) { + return null; + } + + if (!future.isDone()) { + future.run(); + } + + return future.get(); + } +} \ No newline at end of file diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java index 1fd8de80f0acb..0f49b138884cb 100644 --- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java +++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java @@ -135,7 +135,7 @@ public Integer getKey(Event value) throws Exception { harness.processElement(new StreamRecord(new Event(42, "foobar", 1.0), 2)); // simulate snapshot/restore with some elements in internal sorting queue - StreamStateHandle snapshot = harness.snapshot(0, 0); + StreamStateHandle snapshot = harness.snapshotLegacy(0, 0); harness.close(); harness = new OneInputStreamOperatorTestHarness<>( @@ -157,7 +157,7 @@ public Integer getKey(Event value) throws Exception { harness.processWatermark(new Watermark(2)); // simulate snapshot/restore with empty element queue but NFA state - StreamStateHandle snapshot2 = harness.snapshot(1, 1); + StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1); harness.close(); harness = new OneInputStreamOperatorTestHarness<>( @@ -228,7 +228,7 @@ public Integer getKey(Event value) throws Exception { harness.processElement(new StreamRecord(new Event(42, "foobar", 1.0), 2)); // simulate snapshot/restore with some elements in internal sorting queue - StreamStateHandle snapshot = harness.snapshot(0, 0); + StreamStateHandle snapshot = harness.snapshotLegacy(0, 0); harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( @@ -254,7 +254,7 @@ public Integer getKey(Event value) throws Exception { harness.processWatermark(new Watermark(2)); // simulate snapshot/restore with empty element queue but NFA state - StreamStateHandle snapshot2 = harness.snapshot(1, 1); + StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1); harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( @@ -337,7 +337,7 @@ public Integer getKey(Event value) throws Exception { harness.processElement(new StreamRecord(new Event(42, "foobar", 1.0), 2)); // simulate snapshot/restore with some elements in internal sorting queue - StreamStateHandle snapshot = harness.snapshot(0, 0); + StreamStateHandle snapshot = harness.snapshotLegacy(0, 0); harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( @@ -368,7 +368,7 @@ public Integer getKey(Event value) throws Exception { harness.processWatermark(new Watermark(2)); // simulate snapshot/restore with empty element queue but NFA state - StreamStateHandle snapshot2 = harness.snapshot(1, 1); + StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1); harness.close(); harness = new KeyedOneInputStreamOperatorTestHarness<>( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 00028c43d847d..588ba8427b1f4 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -36,22 +36,10 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; @@ -444,11 +432,11 @@ public void run() { // note that checkpoint completion discards the pending checkpoint object if (!checkpoint.isDiscarded()) { LOG.info("Checkpoint " + checkpointID + " expired before completing."); - + checkpoint.abortExpired(); pendingCheckpoints.remove(checkpointID); rememberRecentCheckpointId(checkpointID); - + triggerQueuedRequests(); } } @@ -578,7 +566,7 @@ public boolean receiveDeclineMessage(DeclineCheckpoint message) throws Exception isPendingCheckpoint = true; LOG.info("Discarding checkpoint " + checkpointId - + " because of checkpoint decline from task " + message.getTaskExecutionId()); + + " because of checkpoint decline from task " + message.getTaskExecutionId()); pendingCheckpoints.remove(checkpointId); checkpoint.abortDeclined(); @@ -602,7 +590,7 @@ public boolean receiveDeclineMessage(DeclineCheckpoint message) throws Exception } else if (checkpoint != null) { // this should not happen throw new IllegalStateException( - "Received message for discarded but non-removed checkpoint " + checkpointId); + "Received message for discarded but non-removed checkpoint " + checkpointId); } else { // message is for an unknown checkpoint, or comes too late (checkpoint disposed) if (recentPendingCheckpoints.contains(checkpointId)) { @@ -660,7 +648,7 @@ public boolean receiveAcknowledgeMessage(AcknowledgeCheckpoint message) throws E if (checkpoint.acknowledgeTask( message.getTaskExecutionId(), - message.getCheckpointStateHandles())) { + message.getSubtaskState())) { if (checkpoint.isFullyAcknowledged()) { completed = checkpoint.finalizeCheckpoint(); @@ -804,199 +792,15 @@ public boolean restoreLatestCheckpointedState( LOG.info("Restoring from latest valid checkpoint: {}.", latest); - for (Map.Entry taskGroupStateEntry: latest.getTaskStates().entrySet()) { - TaskState taskState = taskGroupStateEntry.getValue(); - ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey()); - - if (executionJobVertex != null) { - // check that the number of key groups have not changed - if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) { - throw new IllegalStateException("The maximum parallelism (" + - taskState.getMaxParallelism() + ") with which the latest " + - "checkpoint of the execution job vertex " + executionJobVertex + - " has been taken and the current maximum parallelism 
(" + - executionJobVertex.getMaxParallelism() + ") changed. This " + - "is currently not supported."); - } - - - int oldParallelism = taskState.getParallelism(); - int newParallelism = executionJobVertex.getParallelism(); - boolean parallelismChanged = oldParallelism != newParallelism; - boolean hasNonPartitionedState = taskState.hasNonPartitionedState(); - - if (hasNonPartitionedState && parallelismChanged) { - throw new IllegalStateException("Cannot restore the latest checkpoint because " + - "the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " + - "state and its parallelism changed. The operator" + executionJobVertex.getJobVertexId() + - " has parallelism " + newParallelism + " whereas the corresponding" + - "state object has a parallelism of " + oldParallelism); - } - - List keyGroupPartitions = createKeyGroupPartitions( - executionJobVertex.getMaxParallelism(), - newParallelism); - - // operator chain index -> list of the stored partitionables states from all parallel instances - @SuppressWarnings("unchecked") - List[] chainParallelStates = - new List[taskState.getChainLength()]; - - for (int i = 0; i < oldParallelism; ++i) { - - ChainedStateHandle partitionableState = - taskState.getPartitionableState(i); - - if (partitionableState != null) { - for (int j = 0; j < partitionableState.getLength(); ++j) { - OperatorStateHandle opParalleState = partitionableState.get(j); - if (opParalleState != null) { - List opParallelStates = - chainParallelStates[j]; - if (opParallelStates == null) { - opParallelStates = new ArrayList<>(); - chainParallelStates[j] = opParallelStates; - } - opParallelStates.add(opParalleState); - } - } - } - } - - // operator chain index -> lists with collected states (one collection for each parallel subtasks) - @SuppressWarnings("unchecked") - List>[] redistributedParallelStates = - new List[taskState.getChainLength()]; - - //TODO here we can employ different redistribution strategies for state, e.g. union state. For now we only offer round robin as the default. 
- OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; - - for (int i = 0; i < chainParallelStates.length; ++i) { - List chainOpParallelStates = chainParallelStates[i]; - if (chainOpParallelStates != null) { - //We only redistribute if the parallelism of the operator changed from previous executions - if (parallelismChanged) { - redistributedParallelStates[i] = repartitioner.repartitionState( - chainOpParallelStates, - newParallelism); - } else { - List> repacking = new ArrayList<>(newParallelism); - for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { - repacking.add(Collections.singletonList(operatorStateHandle)); - } - redistributedParallelStates[i] = repacking; - } - } - } - - int counter = 0; - - for (int i = 0; i < newParallelism; ++i) { - - // non-partitioned state - ChainedStateHandle state = null; - - if (hasNonPartitionedState) { - SubtaskState subtaskState = taskState.getState(i); - - if (subtaskState != null) { - // count the number of executions for which we set a state - ++counter; - state = subtaskState.getChainedStateHandle(); - } - } - - // partitionable state - @SuppressWarnings("unchecked") - Collection[] ia = new Collection[taskState.getChainLength()]; - List> subTaskPartitionableState = Arrays.asList(ia); - - for (int j = 0; j < redistributedParallelStates.length; ++j) { - List> redistributedParallelState = - redistributedParallelStates[j]; - - if (redistributedParallelState != null) { - subTaskPartitionableState.set(j, redistributedParallelState.get(i)); - } - } + StateAssignmentOperation stateAssignmentOperation = + new StateAssignmentOperation(tasks, latest, allOrNothingState); - // key-partitioned state - KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(i); - - // Again, we only repartition if the parallelism changed - List subtaskKeyGroupStates = parallelismChanged ? - getKeyGroupsStateHandles(taskState.getKeyGroupStates(), subtaskKeyGroupIds) - : Collections.singletonList(taskState.getKeyGroupState(i)); - - Execution currentExecutionAttempt = executionJobVertex - .getTaskVertices()[i] - .getCurrentExecutionAttempt(); - - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles( - state, - null/*subTaskPartionableState*/, //TODO chose right structure and put redistributed states here - subtaskKeyGroupStates); - - currentExecutionAttempt.setInitialState(checkpointStateHandles, subTaskPartitionableState); - } - - if (allOrNothingState && counter > 0 && counter < newParallelism) { - throw new IllegalStateException("The checkpoint contained state only for " + - "a subset of tasks for vertex " + executionJobVertex); - } - } else { - throw new IllegalStateException("There is no execution job vertex for the job" + - " vertex ID " + taskGroupStateEntry.getKey()); - } - } + stateAssignmentOperation.assignStates(); return true; } } - /** - * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct - * key group index for the given subtask {@link KeyGroupRange}. - * - *

This is publicly visible to be used in tests. - */ - public static List getKeyGroupsStateHandles( - Collection allKeyGroupsHandles, - KeyGroupRange subtaskKeyGroupIds) { - - List subtaskKeyGroupStates = new ArrayList<>(); - - for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) { - KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds); - if (intersection.getNumberOfKeyGroups() > 0) { - subtaskKeyGroupStates.add(intersection); - } - } - return subtaskKeyGroupStates; - } - - /** - * Groups the available set of key groups into key group partitions. A key group partition is - * the set of key groups which is assigned to the same task. Each set of the returned list - * constitutes a key group partition. - * - * IMPORTANT: The assignment of key groups to partitions has to be in sync with the - * KeyGroupStreamPartitioner. - * - * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1) - * @param parallelism Parallelism to generate the key group partitioning for - * @return List of key group partitions - */ - public static List createKeyGroupPartitions(int numberKeyGroups, int parallelism) { - Preconditions.checkArgument(numberKeyGroups >= parallelism); - List result = new ArrayList<>(parallelism); - - for (int i = 0; i < parallelism; ++i) { - result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i)); - } - return result; - } - // -------------------------------------------------------------------------------------------- // Accessors // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java index 6f117f2a8f65d..2627b220a466b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java @@ -59,6 +59,15 @@ public CheckpointMetaData( asynchronousDurationMillis); } + public CheckpointMetaData( + long checkpointId, + long timestamp, + CheckpointMetrics metrics) { + this.checkpointId = checkpointId; + this.timestamp = timestamp; + this.metrics = Preconditions.checkNotNull(metrics); + } + public CheckpointMetrics getMetrics() { return metrics; } @@ -110,4 +119,37 @@ public long getSyncDurationMillis() { public long getAsyncDurationMillis() { return metrics.getAsyncDurationMillis(); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + CheckpointMetaData that = (CheckpointMetaData) o; + + return (checkpointId == that.checkpointId) + && (timestamp == that.timestamp) + && (metrics.equals(that.metrics)); + } + + @Override + public int hashCode() { + int result = (int) (checkpointId ^ (checkpointId >>> 32)); + result = 31 * result + (int) (timestamp ^ (timestamp >>> 32)); + result = 31 * result + metrics.hashCode(); + return result; + } + + @Override + public String toString() { + return "CheckpointMetaData{" + + "checkpointId=" + checkpointId + + ", timestamp=" + timestamp + + ", metrics=" + metrics + + '}'; + } } \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java index 6f503929e44bd..92dca21d9dce1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java @@ -28,8 +28,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; @@ -37,7 +35,6 @@ import org.slf4j.LoggerFactory; import java.util.HashMap; -import java.util.List; import java.util.Map; import static org.apache.flink.util.Preconditions.checkArgument; @@ -234,80 +231,61 @@ public CompletedCheckpoint finalizeCheckpoint() throws Exception { public boolean acknowledgeTask( ExecutionAttemptID attemptID, - CheckpointStateHandles checkpointStateHandles) { + SubtaskState checkpointedSubtaskState) { synchronized (lock) { + if (discarded) { return false; } - ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID); - - if (vertex != null) { - if (checkpointStateHandles != null) { - List keyGroupsState = checkpointStateHandles.getKeyGroupsStateHandle(); - ChainedStateHandle nonPartitionedState = - checkpointStateHandles.getNonPartitionedStateHandles(); - ChainedStateHandle partitioneableState = - checkpointStateHandles.getPartitioneableStateHandles(); - - if (nonPartitionedState != null || partitioneableState != null || keyGroupsState != null) { - - JobVertexID jobVertexID = vertex.getJobvertexId(); + final ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID); - int subtaskIndex = vertex.getParallelSubtaskIndex(); + if (vertex == null) { + return false; + } - TaskState taskState; + if (null != checkpointedSubtaskState && checkpointedSubtaskState.hasState()) { - if (taskStates.containsKey(jobVertexID)) { - taskState = taskStates.get(jobVertexID); - } else { - //TODO this should go away when we remove chained state, assigning state to operators directly instead - int chainLength; - if (nonPartitionedState != null) { - chainLength = nonPartitionedState.getLength(); - } else if (partitioneableState != null) { - chainLength = partitioneableState.getLength(); - } else { - chainLength = 1; - } + JobVertexID jobVertexID = vertex.getJobvertexId(); - taskState = new TaskState( - jobVertexID, - vertex.getTotalNumberOfParallelSubtasks(), - vertex.getMaxParallelism(), - chainLength); + int subtaskIndex = vertex.getParallelSubtaskIndex(); - taskStates.put(jobVertexID, taskState); - } + TaskState taskState = taskStates.get(jobVertexID); - long duration = System.currentTimeMillis() - checkpointTimestamp; - - if (nonPartitionedState != null) { - taskState.putState( - subtaskIndex, - new SubtaskState(nonPartitionedState, duration)); - } + if (null == taskState) { + ChainedStateHandle nonPartitionedState = + checkpointedSubtaskState.getLegacyOperatorState(); + ChainedStateHandle partitioneableState = + checkpointedSubtaskState.getManagedOperatorState(); + //TODO this should go away when we remove chained state, assigning state to operators directly instead + int chainLength; + if (nonPartitionedState != null) { + chainLength = nonPartitionedState.getLength(); + } else if 
(partitioneableState != null) { + chainLength = partitioneableState.getLength(); + } else { + chainLength = 1; + } - if(partitioneableState != null && !partitioneableState.isEmpty()) { - taskState.putPartitionableState(subtaskIndex, partitioneableState); - } + taskState = new TaskState( + jobVertexID, + vertex.getTotalNumberOfParallelSubtasks(), + vertex.getMaxParallelism(), + chainLength); - // currently a checkpoint can only contain keyed state - // for the head operator - if (keyGroupsState != null && !keyGroupsState.isEmpty()) { - KeyGroupsStateHandle keyGroupsStateHandle = keyGroupsState.get(0); - taskState.putKeyedState(subtaskIndex, keyGroupsStateHandle); - } - } + taskStates.put(jobVertexID, taskState); } - ++numAcknowledgedTasks; + long duration = System.currentTimeMillis() - checkpointTimestamp; + checkpointedSubtaskState.setDuration(duration); - return true; - } else { - return false; + taskState.putState(subtaskIndex, checkpointedSubtaskState); } + + ++numAcknowledgedTasks; + + return true; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 09a35f659c6f6..16a7e27048908 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -176,7 +176,7 @@ private List> repartition( Map mergeMap = mergeMapList.get(parallelOpIdx); OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0); if (psh == null) { - psh = new OperatorStateHandle(handleWithOffsets.f0, new HashMap()); + psh = new OperatorStateHandle(new HashMap(), handleWithOffsets.f0); mergeMap.put(handleWithOffsets.f0, psh); } psh.getStateNameToPartitionOffsets().put(e.getKey(), offs); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java new file mode 100644 index 0000000000000..8e2b0bf2a4198 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * This class encapsulates the operation of assigning restored state when restoring from a checkpoint. + */ +public class StateAssignmentOperation { + + public StateAssignmentOperation( + Map tasks, + CompletedCheckpoint latest, + boolean allOrNothingState) { + + this.tasks = tasks; + this.latest = latest; + this.allOrNothingState = allOrNothingState; + } + + private final Map tasks; + private final CompletedCheckpoint latest; + private final boolean allOrNothingState; + + public boolean assignStates() throws Exception { + + for (Map.Entry taskGroupStateEntry : latest.getTaskStates().entrySet()) { + TaskState taskState = taskGroupStateEntry.getValue(); + ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey()); + + if (executionJobVertex != null) { + // check that the number of key groups have not changed + if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) { + throw new IllegalStateException("The maximum parallelism (" + + taskState.getMaxParallelism() + ") with which the latest " + + "checkpoint of the execution job vertex " + executionJobVertex + + " has been taken and the current maximum parallelism (" + + executionJobVertex.getMaxParallelism() + ") changed. This " + + "is currently not supported."); + } + + final int oldParallelism = taskState.getParallelism(); + final int newParallelism = executionJobVertex.getParallelism(); + final boolean parallelismChanged = oldParallelism != newParallelism; + final boolean hasNonPartitionedState = taskState.hasNonPartitionedState(); + + if (hasNonPartitionedState && parallelismChanged) { + throw new IllegalStateException("Cannot restore the latest checkpoint because " + + "the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " + + "state and its parallelism changed. 
The operator" + executionJobVertex.getJobVertexId() + + " has parallelism " + newParallelism + " whereas the corresponding" + + "state object has a parallelism of " + oldParallelism); + } + + List keyGroupPartitions = createKeyGroupPartitions( + executionJobVertex.getMaxParallelism(), + newParallelism); + + final int chainLength = taskState.getChainLength(); + + // operator chain idx -> list of the stored op states from all parallel instances for this chain idx + @SuppressWarnings("unchecked") + List[] parallelOpStatesBackend = new List[chainLength]; + @SuppressWarnings("unchecked") + List[] parallelOpStatesStream = new List[chainLength]; + + List parallelKeyedStatesBackend = new ArrayList<>(oldParallelism); + List parallelKeyedStateStream = new ArrayList<>(oldParallelism); + + int counter = 0; + for (int p = 0; p < oldParallelism; ++p) { + + SubtaskState subtaskState = taskState.getState(p); + + if (null != subtaskState) { + + ++counter; + + collectParallelStatesByChainOperator( + parallelOpStatesBackend, subtaskState.getManagedOperatorState()); + + collectParallelStatesByChainOperator( + parallelOpStatesStream, subtaskState.getRawOperatorState()); + + KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); + if (null != keyedStateBackend) { + parallelKeyedStatesBackend.add(keyedStateBackend); + } + + KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState(); + if (null != keyedStateStream) { + parallelKeyedStateStream.add(keyedStateStream); + } + } + } + + if (allOrNothingState && counter > 0 && counter < oldParallelism) { + throw new IllegalStateException("The checkpoint contained state only for " + + "a subset of tasks for vertex " + executionJobVertex); + } + + // operator chain index -> lists with collected states (one collection for each parallel subtasks) + @SuppressWarnings("unchecked") + List>[] partitionedParallelStatesBackend = new List[chainLength]; + + @SuppressWarnings("unchecked") + List>[] partitionedParallelStatesStream = new List[chainLength]; + + //TODO here we can employ different redistribution strategies for state, e.g. union state. + // For now we only offer round robin as the default. 
+ OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; + + for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) { + + List chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx]; + List chainOpParallelStatesStream = parallelOpStatesStream[chainIdx]; + + partitionedParallelStatesBackend[chainIdx] = applyRepartitioner( + opStateRepartitioner, + chainOpParallelStatesBackend, + oldParallelism, + newParallelism); + + partitionedParallelStatesStream[chainIdx] = applyRepartitioner( + opStateRepartitioner, + chainOpParallelStatesStream, + oldParallelism, + newParallelism); + } + + for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) { + // non-partitioned state + ChainedStateHandle nonPartitionableState = null; + + if (hasNonPartitionedState) { + // count the number of executions for which we set a state + nonPartitionableState = taskState.getState(subTaskIdx).getLegacyOperatorState(); + } + + // partitionable state + @SuppressWarnings("unchecked") + Collection[] iab = new Collection[chainLength]; + @SuppressWarnings("unchecked") + Collection[] ias = new Collection[chainLength]; + List> operatorStateFromBackend = Arrays.asList(iab); + List> operatorStateFromStream = Arrays.asList(ias); + + for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) { + List> redistributedOpStateBackend = + partitionedParallelStatesBackend[chainIdx]; + + List> redistributedOpStateStream = + partitionedParallelStatesStream[chainIdx]; + + if (redistributedOpStateBackend != null) { + operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx)); + } + + if (redistributedOpStateStream != null) { + operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx)); + } + } + + Execution currentExecutionAttempt = executionJobVertex + .getTaskVertices()[subTaskIdx] + .getCurrentExecutionAttempt(); + + List newKeyedStatesBackend; + List newKeyedStateStream; + if (parallelismChanged) { + KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx); + newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds); + newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds); + } else { + SubtaskState subtaskState = taskState.getState(subTaskIdx); + KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState(); + KeyGroupsStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState(); + newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(oldKeyedStatesBackend) : null; + newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(oldKeyedStatesStream) : null; + } + + TaskStateHandles taskStateHandles = new TaskStateHandles( + nonPartitionableState, + operatorStateFromBackend, + operatorStateFromStream, + newKeyedStatesBackend, + newKeyedStateStream); + + currentExecutionAttempt.setInitialState(taskStateHandles); + } + + } else { + throw new IllegalStateException("There is no execution job vertex for the job" + + " vertex ID " + taskGroupStateEntry.getKey()); + } + } + + return true; + + } + + /** + * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct + * key group index for the given subtask {@link KeyGroupRange}. + *

+ *

This is publicly visible to be used in tests. + */ + public static List getKeyGroupsStateHandles( + Collection allKeyGroupsHandles, + KeyGroupRange subtaskKeyGroupIds) { + + List subtaskKeyGroupStates = new ArrayList<>(); + + for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) { + KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds); + if (intersection.getNumberOfKeyGroups() > 0) { + subtaskKeyGroupStates.add(intersection); + } + } + return subtaskKeyGroupStates; + } + + /** + * Groups the available set of key groups into key group partitions. A key group partition is + * the set of key groups which is assigned to the same task. Each set of the returned list + * constitutes a key group partition. + *

+ * IMPORTANT: The assignment of key groups to partitions has to be in sync with the + * KeyGroupStreamPartitioner. + * + * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1) + * @param parallelism Parallelism to generate the key group partitioning for + * @return List of key group partitions + */ + public static List createKeyGroupPartitions(int numberKeyGroups, int parallelism) { + Preconditions.checkArgument(numberKeyGroups >= parallelism); + List result = new ArrayList<>(parallelism); + + for (int i = 0; i < parallelism; ++i) { + result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i)); + } + return result; + } + + /** + * @param chainParallelOpStates array = chain ops, array[idx] = parallel states for this chain op. + * @param chainOpState + */ + private static void collectParallelStatesByChainOperator( + List[] chainParallelOpStates, ChainedStateHandle chainOpState) { + + if (null != chainOpState) { + for (int chainIdx = 0; chainIdx < chainParallelOpStates.length; ++chainIdx) { + OperatorStateHandle operatorState = chainOpState.get(chainIdx); + + if (null != operatorState) { + + List opParallelStatesForOneChainOp = chainParallelOpStates[chainIdx]; + + if (null == opParallelStatesForOneChainOp) { + opParallelStatesForOneChainOp = new ArrayList<>(); + chainParallelOpStates[chainIdx] = opParallelStatesForOneChainOp; + } + opParallelStatesForOneChainOp.add(operatorState); + } + } + } + } + + private static List> applyRepartitioner( + OperatorStateRepartitioner opStateRepartitioner, + List chainOpParallelStates, + int oldParallelism, + int newParallelism) { + + if (chainOpParallelStates == null) { + return null; + } + + //We only redistribute if the parallelism of the operator changed from previous executions + if (newParallelism != oldParallelism) { + + return opStateRepartitioner.repartitionState( + chainOpParallelStates, + newParallelism); + } else { + + List> repackStream = new ArrayList<>(newParallelism); + for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) { + repackStream.add(Collections.singletonList(operatorStateHandle)); + } + return repackStream; + } + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index 2aa049127b754..1865a0cff5c34 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -19,10 +19,13 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateObject; +import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import java.util.Arrays; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -34,10 +37,15 @@ public class SubtaskState implements StateObject { private static final long serialVersionUID = -2394696997971923995L; - private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class); - - /** The state of the parallel operator */ - private final ChainedStateHandle chainedStateHandle; + /** + * The state of the parallel 
operator + */ + @Deprecated + private final ChainedStateHandle legacyOperatorState; + private final ChainedStateHandle managedOperatorState; + private final ChainedStateHandle rawOperatorState; + private final KeyGroupsStateHandle managedKeyedState; + private final KeyGroupsStateHandle rawKeyedState; /** * The state size. This is also part of the deserialized state handle. @@ -46,26 +54,76 @@ public class SubtaskState implements StateObject { */ private final long stateSize; - /** The duration of the checkpoint (ack timestamp - trigger timestamp). */ - private final long duration; - + /** + * The duration of the checkpoint (ack timestamp - trigger timestamp). + */ + private long duration; + public SubtaskState( - ChainedStateHandle chainedStateHandle, + ChainedStateHandle legacyOperatorState, + ChainedStateHandle managedOperatorState, + ChainedStateHandle rawOperatorState, + KeyGroupsStateHandle managedKeyedState, + KeyGroupsStateHandle rawKeyedState) { + this(legacyOperatorState, + managedOperatorState, + rawOperatorState, + managedKeyedState, + rawKeyedState, + 0L); + } + + public SubtaskState( + ChainedStateHandle legacyOperatorState, + ChainedStateHandle managedOperatorState, + ChainedStateHandle rawOperatorState, + KeyGroupsStateHandle managedKeyedState, + KeyGroupsStateHandle rawKeyedState, long duration) { - this.chainedStateHandle = checkNotNull(chainedStateHandle, "State"); + this.legacyOperatorState = checkNotNull(legacyOperatorState, "State"); + this.managedOperatorState = managedOperatorState; + this.rawOperatorState = rawOperatorState; + this.managedKeyedState = managedKeyedState; + this.rawKeyedState = rawKeyedState; this.duration = duration; try { - stateSize = chainedStateHandle.getStateSize(); + long calculateStateSize = getSizeNullSafe(legacyOperatorState); + calculateStateSize += getSizeNullSafe(managedOperatorState); + calculateStateSize += getSizeNullSafe(rawOperatorState); + calculateStateSize += getSizeNullSafe(managedKeyedState); + calculateStateSize += getSizeNullSafe(rawKeyedState); + stateSize = calculateStateSize; } catch (Exception e) { throw new RuntimeException("Failed to get state size.", e); } } + private static final long getSizeNullSafe(StateObject stateObject) throws Exception { + return stateObject != null ? 
stateObject.getStateSize() : 0L; + } + // -------------------------------------------------------------------------------------------- - - public ChainedStateHandle getChainedStateHandle() { - return chainedStateHandle; + + @Deprecated + public ChainedStateHandle getLegacyOperatorState() { + return legacyOperatorState; + } + + public ChainedStateHandle getManagedOperatorState() { + return managedOperatorState; + } + + public ChainedStateHandle getRawOperatorState() { + return rawOperatorState; + } + + public KeyGroupsStateHandle getManagedKeyedState() { + return managedKeyedState; + } + + public KeyGroupsStateHandle getRawKeyedState() { + return rawKeyedState; } @Override @@ -79,35 +137,94 @@ public long getDuration() { @Override public void discardState() throws Exception { - chainedStateHandle.discardState(); + StateUtil.bestEffortDiscardAllStateObjects( + Arrays.asList( + legacyOperatorState, + managedOperatorState, + rawOperatorState, + managedKeyedState, + rawKeyedState)); + } + + public void setDuration(long duration) { + this.duration = duration; } // -------------------------------------------------------------------------------------------- + @Override public boolean equals(Object o) { if (this == o) { return true; } - else if (o instanceof SubtaskState) { - SubtaskState that = (SubtaskState) o; - return this.chainedStateHandle.equals(that.chainedStateHandle) && stateSize == that.stateSize && - duration == that.duration; + if (o == null || getClass() != o.getClass()) { + return false; + } + + SubtaskState that = (SubtaskState) o; + + if (stateSize != that.stateSize) { + return false; + } + if (duration != that.duration) { + return false; + } + if (legacyOperatorState != null ? + !legacyOperatorState.equals(that.legacyOperatorState) + : that.legacyOperatorState != null) { + return false; + } + if (managedOperatorState != null ? + !managedOperatorState.equals(that.managedOperatorState) + : that.managedOperatorState != null) { + return false; } - else { + if (rawOperatorState != null ? + !rawOperatorState.equals(that.rawOperatorState) + : that.rawOperatorState != null) { return false; } + if (managedKeyedState != null ? + !managedKeyedState.equals(that.managedKeyedState) + : that.managedKeyedState != null) { + return false; + } + return rawKeyedState != null ? + rawKeyedState.equals(that.rawKeyedState) + : that.rawKeyedState == null; + + } + + public boolean hasState() { + return (null != legacyOperatorState && !legacyOperatorState.isEmpty()) + || (null != managedOperatorState && !managedOperatorState.isEmpty()) + || null != managedKeyedState + || null != rawKeyedState; } @Override public int hashCode() { - return (int) (this.stateSize ^ this.stateSize >>> 32) + - 31 * ((int) (this.duration ^ this.duration >>> 32) + - 31 * chainedStateHandle.hashCode()); + int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; + result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); + result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); + result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); + result = 31 * result + (rawKeyedState != null ? 
rawKeyedState.hashCode() : 0); + result = 31 * result + (int) (stateSize ^ (stateSize >>> 32)); + result = 31 * result + (int) (duration ^ (duration >>> 32)); + return result; } @Override public String toString() { - return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, chainedStateHandle); + return "SubtaskState{" + + "chainedStateHandle=" + legacyOperatorState + + ", operatorStateFromBackend=" + managedOperatorState + + ", operatorStateFromStream=" + rawOperatorState + + ", keyedStateFromBackend=" + managedKeyedState + + ", keyedStateHandleFromStream=" + rawKeyedState + + ", stateSize=" + stateSize + + ", duration=" + duration + + '}'; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index 7e4eded86e10a..3cdc5e95e92c9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -18,11 +18,7 @@ package org.apache.flink.runtime.checkpoint; -import com.google.common.collect.Iterables; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.util.Preconditions; @@ -49,12 +45,6 @@ public class TaskState implements StateObject { /** handles to non-partitioned states, subtaskindex -> subtaskstate */ private final Map subtaskStates; - /** handles to partitionable states, subtaskindex -> partitionable state */ - private final Map> partitionableStates; - - /** handles to key-partitioned states, subtaskindex -> keyed state */ - private final Map keyGroupsStateHandles; - /** parallelism of the operator when it was checkpointed */ private final int parallelism; @@ -62,6 +52,7 @@ public class TaskState implements StateObject { /** maximum parallelism of the operator when the job was first created */ private final int maxParallelism; + /** length of the operator chain */ private final int chainLength; public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism, int chainLength) { @@ -73,8 +64,6 @@ public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism, i this.jobVertexID = jobVertexID; this.subtaskStates = new HashMap<>(parallelism); - this.partitionableStates = new HashMap<>(parallelism); - this.keyGroupsStateHandles = new HashMap<>(parallelism); this.parallelism = parallelism; this.maxParallelism = maxParallelism; @@ -96,32 +85,6 @@ public void putState(int subtaskIndex, SubtaskState subtaskState) { } } - public void putPartitionableState( - int subtaskIndex, - ChainedStateHandle partitionableState) { - - Preconditions.checkNotNull(partitionableState); - - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + subtaskStates.size()); - } else { - partitionableStates.put(subtaskIndex, partitionableState); - } - } - - public void putKeyedState(int subtaskIndex, KeyGroupsStateHandle keyGroupsStateHandle) { - Preconditions.checkNotNull(keyGroupsStateHandle); - - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The 
given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + subtaskStates.size()); - } else { - keyGroupsStateHandles.put(subtaskIndex, keyGroupsStateHandle); - } - } - - public SubtaskState getState(int subtaskIndex) { if (subtaskIndex < 0 || subtaskIndex >= parallelism) { throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + @@ -131,24 +94,6 @@ public SubtaskState getState(int subtaskIndex) { } } - public ChainedStateHandle getPartitionableState(int subtaskIndex) { - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + subtaskStates.size()); - } else { - return partitionableStates.get(subtaskIndex); - } - } - - public KeyGroupsStateHandle getKeyGroupState(int subtaskIndex) { - if (subtaskIndex < 0 || subtaskIndex >= parallelism) { - throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + - " exceeds the maximum number of sub tasks " + keyGroupsStateHandles.size()); - } else { - return keyGroupsStateHandles.get(subtaskIndex); - } - } - public Collection getStates() { return subtaskStates.values(); } @@ -169,13 +114,9 @@ public int getChainLength() { return chainLength; } - public Collection getKeyGroupStates() { - return keyGroupsStateHandles.values(); - } - public boolean hasNonPartitionedState() { for(SubtaskState sts : subtaskStates.values()) { - if (sts != null && !sts.getChainedStateHandle().isEmpty()) { + if (sts != null && !sts.getLegacyOperatorState().isEmpty()) { return true; } } @@ -184,8 +125,7 @@ public boolean hasNonPartitionedState() { @Override public void discardState() throws Exception { - StateUtil.bestEffortDiscardAllStateObjects( - Iterables.concat(subtaskStates.values(), partitionableStates.values(), keyGroupsStateHandles.values())); + StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values()); } @@ -198,16 +138,6 @@ public long getStateSize() throws IOException { if (subtaskState != null) { result += subtaskState.getStateSize(); } - - ChainedStateHandle partitionableState = partitionableStates.get(i); - if (partitionableState != null) { - result += partitionableState.getStateSize(); - } - - KeyGroupsStateHandle keyGroupsState = keyGroupsStateHandles.get(i); - if (keyGroupsState != null) { - result += keyGroupsState.getStateSize(); - } } return result; @@ -220,9 +150,7 @@ public boolean equals(Object obj) { return jobVertexID.equals(other.jobVertexID) && parallelism == other.parallelism - && subtaskStates.equals(other.subtaskStates) - && partitionableStates.equals(other.partitionableStates) - && keyGroupsStateHandles.equals(other.keyGroupsStateHandles); + && subtaskStates.equals(other.subtaskStates); } else { return false; } @@ -230,18 +158,10 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, partitionableStates, keyGroupsStateHandles); + return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates); } public Map getSubtaskStates() { return Collections.unmodifiableMap(subtaskStates); } - - public Map getKeyGroupsStateHandles() { - return Collections.unmodifiableMap(keyGroupsStateHandles); - } - - public Map> getPartitionableStates() { - return partitionableStates; - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java index 666176b638cae..89f1f42b2157f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java @@ -80,46 +80,13 @@ public void serialize(SavepointV1 savepoint, DataOutputStream dos) throws IOExce dos.writeInt(taskState.getMaxParallelism()); dos.writeInt(taskState.getChainLength()); - // Sub task non-partitionable states + // Sub task states Map subtaskStateMap = taskState.getSubtaskStates(); dos.writeInt(subtaskStateMap.size()); for (Map.Entry entry : subtaskStateMap.entrySet()) { dos.writeInt(entry.getKey()); - - SubtaskState subtaskState = entry.getValue(); - ChainedStateHandle chainedStateHandle = subtaskState.getChainedStateHandle(); - dos.writeInt(chainedStateHandle.getLength()); - for (int j = 0; j < chainedStateHandle.getLength(); ++j) { - StreamStateHandle stateHandle = chainedStateHandle.get(j); - serializeStreamStateHandle(stateHandle, dos); - } - - dos.writeLong(subtaskState.getDuration()); + serializeSubtaskState(entry.getValue(), dos); } - - // Sub task partitionable states - Map> partitionableStatesMap = taskState.getPartitionableStates(); - dos.writeInt(partitionableStatesMap.size()); - - for (Map.Entry> entry : partitionableStatesMap.entrySet()) { - dos.writeInt(entry.getKey()); - - ChainedStateHandle chainedStateHandle = entry.getValue(); - dos.writeInt(chainedStateHandle.getLength()); - for (int j = 0; j < chainedStateHandle.getLength(); ++j) { - OperatorStateHandle stateHandle = chainedStateHandle.get(j); - serializePartitionableStateHandle(stateHandle, dos); - } - } - - // Keyed state - Map keyGroupsStateHandles = taskState.getKeyGroupsStateHandles(); - dos.writeInt(keyGroupsStateHandles.size()); - for (Map.Entry entry : keyGroupsStateHandles.entrySet()) { - dos.writeInt(entry.getKey()); - serializeKeyGroupStateHandle(entry.getValue(), dos); - } - } } catch (Exception e) { throw new IOException(e); @@ -149,50 +116,99 @@ public SavepointV1 deserialize(DataInputStream dis) throws IOException { for (int j = 0; j < numSubTaskStates; j++) { int subtaskIndex = dis.readInt(); - int chainedStateHandleSize = dis.readInt(); - List streamStateHandleList = new ArrayList<>(chainedStateHandleSize); - for (int k = 0; k < chainedStateHandleSize; ++k) { - StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis); - streamStateHandleList.add(streamStateHandle); - } - - long duration = dis.readLong(); - ChainedStateHandle chainedStateHandle = new ChainedStateHandle<>(streamStateHandleList); - SubtaskState subtaskState = new SubtaskState(chainedStateHandle, duration); + SubtaskState subtaskState = deserializeSubtaskState(dis); taskState.putState(subtaskIndex, subtaskState); } + } - int numPartitionableOpStates = dis.readInt(); + return new SavepointV1(checkpointId, taskStates); + } - for (int j = 0; j < numPartitionableOpStates; j++) { - int subtaskIndex = dis.readInt(); - int chainedStateHandleSize = dis.readInt(); - List streamStateHandleList = new ArrayList<>(chainedStateHandleSize); + public static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException { - for (int k = 0; k < chainedStateHandleSize; ++k) { - OperatorStateHandle streamStateHandle = deserializePartitionableStateHandle(dis); - streamStateHandleList.add(streamStateHandle); - } + 
dos.writeLong(subtaskState.getDuration()); - ChainedStateHandle chainedStateHandle = - new ChainedStateHandle<>(streamStateHandleList); + ChainedStateHandle nonPartitionableState = subtaskState.getLegacyOperatorState(); - taskState.putPartitionableState(subtaskIndex, chainedStateHandle); - } + int len = nonPartitionableState != null ? nonPartitionableState.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + StreamStateHandle stateHandle = nonPartitionableState.get(i); + serializeStreamStateHandle(stateHandle, dos); + } - // Key group states - int numKeyGroupStates = dis.readInt(); - for (int j = 0; j < numKeyGroupStates; j++) { - int keyGroupIndex = dis.readInt(); + ChainedStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); - KeyGroupsStateHandle keyGroupsStateHandle = deserializeKeyGroupStateHandle(dis); - if (keyGroupsStateHandle != null) { - taskState.putKeyedState(keyGroupIndex, keyGroupsStateHandle); - } - } + len = operatorStateBackend != null ? operatorStateBackend.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle stateHandle = operatorStateBackend.get(i); + serializeOperatorStateHandle(stateHandle, dos); } - return new SavepointV1(checkpointId, taskStates); + ChainedStateHandle operatorStateFromStream = subtaskState.getRawOperatorState(); + + len = operatorStateFromStream != null ? operatorStateFromStream.getLength() : 0; + dos.writeInt(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle stateHandle = operatorStateFromStream.get(i); + serializeOperatorStateHandle(stateHandle, dos); + } + + KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); + serializeKeyGroupStateHandle(keyedStateBackend, dos); + + KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState(); + serializeKeyGroupStateHandle(keyedStateStream, dos); + + } + + public static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException { + + long duration = dis.readLong(); + + int len = dis.readInt(); + List nonPartitionableState = new ArrayList<>(len); + for (int i = 0; i < len; ++i) { + StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis); + nonPartitionableState.add(streamStateHandle); + } + + + len = dis.readInt(); + List operatorStateBackend = new ArrayList<>(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); + operatorStateBackend.add(streamStateHandle); + } + + len = dis.readInt(); + List operatorStateStream = new ArrayList<>(len); + for (int i = 0; i < len; ++i) { + OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); + operatorStateStream.add(streamStateHandle); + } + + KeyGroupsStateHandle keyedStateBackend = deserializeKeyGroupStateHandle(dis); + + KeyGroupsStateHandle keyedStateStream = deserializeKeyGroupStateHandle(dis); + + ChainedStateHandle nonPartitionableStateChain = + new ChainedStateHandle<>(nonPartitionableState); + + ChainedStateHandle operatorStateBackendChain = + new ChainedStateHandle<>(operatorStateBackend); + + ChainedStateHandle operatorStateStreamChain = + new ChainedStateHandle<>(operatorStateStream); + + return new SubtaskState( + nonPartitionableStateChain, + operatorStateBackendChain, + operatorStateStreamChain, + keyedStateBackend, + keyedStateStream, + duration); } public static void serializeKeyGroupStateHandle( @@ -231,7 +247,7 @@ public static KeyGroupsStateHandle deserializeKeyGroupStateHandle(DataInputStrea } } - public 
static void serializePartitionableStateHandle( + public static void serializeOperatorStateHandle( OperatorStateHandle stateHandle, DataOutputStream dos) throws IOException { if (stateHandle != null) { @@ -252,7 +268,7 @@ public static void serializePartitionableStateHandle( } } - public static OperatorStateHandle deserializePartitionableStateHandle( + public static OperatorStateHandle deserializeOperatorStateHandle( DataInputStream dis) throws IOException { final int type = dis.readByte(); @@ -270,13 +286,14 @@ public static OperatorStateHandle deserializePartitionableStateHandle( offsetsMap.put(key, offsets); } StreamStateHandle stateHandle = deserializeStreamStateHandle(dis); - return new OperatorStateHandle(stateHandle, offsetsMap); + return new OperatorStateHandle(offsetsMap, stateHandle); } else { throw new IllegalStateException("Reading invalid OperatorStateHandle, type: " + type); } } - public static void serializeStreamStateHandle(StreamStateHandle stateHandle, DataOutputStream dos) throws IOException { + public static void serializeStreamStateHandle( + StreamStateHandle stateHandle, DataOutputStream dos) throws IOException { if (stateHandle == null) { dos.writeByte(NULL_HANDLE); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 7bbdb2acde9ac..bf31e5141e1bc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -25,10 +25,7 @@ import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.SerializedValue; import java.io.Serializable; @@ -36,7 +33,6 @@ import java.util.Collection; import java.util.List; - import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -95,13 +91,7 @@ public final class TaskDeploymentDescriptor implements Serializable { /** The list of classpaths required to run this task. */ private final List requiredClasspaths; - /** Handle to the non-partitioned state of the operator chain */ - private final ChainedStateHandle operatorState; - - /** Handle to the key-grouped state of the head operator in the chain */ - private final List keyGroupState; - - private final List> partitionableOperatorState; + private final TaskStateHandles taskStateHandles; /** The execution configuration (see {@link ExecutionConfig}) related to the specific job. 
*/ private final SerializedValue serializedExecutionConfig; @@ -128,9 +118,7 @@ public TaskDeploymentDescriptor( List requiredJarFiles, List requiredClasspaths, int targetSlotNumber, - ChainedStateHandle operatorState, - List keyGroupState, - List> partitionableOperatorStateHandles) { + TaskStateHandles taskStateHandles) { checkArgument(indexInSubtaskGroup >= 0); checkArgument(numberOfSubtasks > indexInSubtaskGroup); @@ -155,9 +143,7 @@ public TaskDeploymentDescriptor( this.requiredJarFiles = checkNotNull(requiredJarFiles); this.requiredClasspaths = checkNotNull(requiredClasspaths); this.targetSlotNumber = targetSlotNumber; - this.operatorState = operatorState; - this.keyGroupState = keyGroupState; - this.partitionableOperatorState = partitionableOperatorStateHandles; + this.taskStateHandles = taskStateHandles; } public TaskDeploymentDescriptor( @@ -199,8 +185,6 @@ public TaskDeploymentDescriptor( requiredJarFiles, requiredClasspaths, targetSlotNumber, - null, - null, null); } @@ -346,15 +330,7 @@ private String collectionToString(Collection collection) { return strBuilder.toString(); } - public ChainedStateHandle getOperatorState() { - return operatorState; - } - - public List getKeyGroupState() { - return keyGroupState; - } - - public List> getPartitionableOperatorState() { - return partitionableOperatorState; + public TaskStateHandles getTaskStateHandles() { + return taskStateHandles; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index f0ff918ad7a75..af1a640347bc5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; @@ -35,7 +36,6 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -171,12 +171,12 @@ public interface Environment { * the checkpoint with the give checkpoint-ID. This method does include * the given state in the checkpoint. 
* - * @param checkpointStateHandles All state handles for the checkpointed state * @param checkpointMetaData the meta data for this checkpoint + * @param subtaskState All state handles for the checkpointed state */ void acknowledgeCheckpoint( CheckpointMetaData checkpointMetaData, - CheckpointStateHandles checkpointStateHandles); + SubtaskState subtaskState); /** * Marks task execution failed for an external reason (a reason other than the task code itself diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index 0b56931ef1fa5..17e0df1a2fc2f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.deployment.ResultPartitionLocation; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.instance.SimpleSlot; import org.apache.flink.runtime.instance.SlotProvider; @@ -47,19 +46,14 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.messages.Messages; import org.apache.flink.runtime.messages.TaskMessages.TaskOperationResult; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; - import scala.concurrent.ExecutionContext; import scala.concurrent.duration.FiniteDuration; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -136,12 +130,7 @@ public class Execution implements AccessExecution, Archiveable chainedStateHandle; - - private List> chainedPartitionableStateHandle; - - private List keyGroupsStateHandles; - + private TaskStateHandles taskStateHandles; /** The execution context which is used to execute futures. */ private ExecutionContext executionContext; @@ -232,39 +221,27 @@ public long getStateTimestamp(ExecutionState state) { return this.stateTimestamps[state.ordinal()]; } - public ChainedStateHandle getChainedStateHandle() { - return chainedStateHandle; - } - - public List getKeyGroupsStateHandles() { - return keyGroupsStateHandles; - } - - public List> getChainedPartitionableStateHandle() { - return chainedPartitionableStateHandle; - } - public boolean isFinished() { return state.isTerminal(); } + public TaskStateHandles getTaskStateHandles() { + return taskStateHandles; + } + /** * Sets the initial state for the execution. The serialized state is then shipped via the * {@link TaskDeploymentDescriptor} to the TaskManagers. 
* * @param checkpointStateHandles all checkpointed operator state */ - public void setInitialState(CheckpointStateHandles checkpointStateHandles, List> chainedPartitionableStateHandle) { + public void setInitialState(TaskStateHandles checkpointStateHandles) { if (state != ExecutionState.CREATED) { throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED"); } - if(checkpointStateHandles != null) { - this.chainedStateHandle = checkpointStateHandles.getNonPartitionedStateHandles(); - this.chainedPartitionableStateHandle = chainedPartitionableStateHandle; - this.keyGroupsStateHandles = checkpointStateHandles.getKeyGroupsStateHandle(); - } + this.taskStateHandles = checkpointStateHandles; } // -------------------------------------------------------------------------------------------- @@ -390,9 +367,7 @@ public void deployToSlot(final SimpleSlot slot) throws JobException { final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor( attemptId, slot, - chainedStateHandle, - keyGroupsStateHandles, - chainedPartitionableStateHandle, + taskStateHandles, attemptNumber); // register this execution at the execution graph, to receive call backs diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 96af91eb7c117..b647385a0f0ef 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -18,7 +18,9 @@ package org.apache.flink.runtime.executiongraph; +import org.apache.flink.api.common.Archiveable; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; @@ -27,36 +29,28 @@ import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.api.common.Archiveable; -import org.apache.flink.runtime.instance.SlotProvider; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.instance.SimpleSlot; +import org.apache.flink.runtime.instance.SlotProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobEdge; -import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.NoResourceAvailableException; +import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.SerializedValue; -import 
org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; - import org.slf4j.Logger; - import scala.concurrent.duration.FiniteDuration; import java.io.Serializable; import java.net.URL; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; @@ -622,9 +616,7 @@ void notifyStateTransition(ExecutionAttemptID executionId, ExecutionState newSta TaskDeploymentDescriptor createDeploymentDescriptor( ExecutionAttemptID executionId, SimpleSlot targetSlot, - ChainedStateHandle operatorState, - List keyGroupStates, - List> partitionableOperatorStateHandle, + TaskStateHandles taskStateHandles, int attemptNumber) { // Produced intermediate results @@ -676,9 +668,7 @@ TaskDeploymentDescriptor createDeploymentDescriptor( jarFiles, classpaths, targetSlot.getRoot().getSlotNumber(), - operatorState, - keyGroupStates, - partitionableOperatorStateHandle); + taskStateHandles); } // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java index 8893ba4839f7e..47b63f04ac527 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java @@ -18,11 +18,10 @@ package org.apache.flink.runtime.fs.hdfs; -import java.io.IOException; - import org.apache.flink.core.fs.FSDataInputStream; import javax.annotation.Nonnull; +import java.io.IOException; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -46,7 +45,11 @@ public HadoopDataInputStream(org.apache.hadoop.fs.FSDataInputStream fsDataInputS @Override public void seek(long desired) throws IOException { - fsDataInputStream.seek(desired); + // This optimization prevents some implementations of distributed FS from performing expensive seeks when they are + // actually not needed + if (desired != getPos()) { + fsDataInputStream.seek(desired); + } } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java index e1d15e22ac00e..b0c3730caa3d0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java @@ -19,13 +19,7 @@ package org.apache.flink.runtime.jobgraph.tasks; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; - -import java.util.Collection; -import java.util.List; +import org.apache.flink.runtime.state.TaskStateHandles; /** * This interface must be implemented by any invokable that has recoverable state and participates @@ -37,15 +31,9 @@ public interface StatefulTask { /** * Sets the initial state of the operator, upon recovery. The initial state is typically * a snapshot of the state from a previous execution. 
* - * TODO this should use @{@link org.apache.flink.runtime.state.CheckpointStateHandles} after redoing chained state. - * - * @param chainedState Handle for the chained operator states. - * @param keyGroupsState Handle for key group states. + * @param taskStateHandles All state handles for the task. */ - void setInitialState( - ChainedStateHandle chainedState, - List keyGroupsState, - List> partitionableOperatorState) throws Exception; + void setInitialState(TaskStateHandles taskStateHandles) throws Exception; /** * This method is called to trigger a checkpoint, asynchronously by the checkpoint diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java index ac14d3aeb1d87..c63bac5708c2c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java @@ -20,8 +20,8 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.state.CheckpointStateHandles; import static org.apache.flink.util.Preconditions.checkArgument; @@ -38,7 +38,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements private static final long serialVersionUID = -7606214777192401493L; - private final CheckpointStateHandles checkpointStateHandles; + private final SubtaskState subtaskState; private final CheckpointMetaData checkpointMetaData; @@ -55,11 +55,11 @@ public AcknowledgeCheckpoint( JobID job, ExecutionAttemptID taskExecutionId, CheckpointMetaData checkpointMetaData, - CheckpointStateHandles checkpointStateHandles) { + SubtaskState subtaskState) { super(job, taskExecutionId, checkpointMetaData.getCheckpointId()); - this.checkpointStateHandles = checkpointStateHandles; + this.subtaskState = subtaskState; this.checkpointMetaData = checkpointMetaData; // these may be "-1", in case the values are unknown or not set checkArgument(checkpointMetaData.getSyncDurationMillis() >= -1); @@ -72,8 +72,8 @@ public AcknowledgeCheckpoint( // properties // ------------------------------------------------------------------------ - public CheckpointStateHandles getCheckpointStateHandles() { - return checkpointStateHandles; + public SubtaskState getSubtaskState() { + return subtaskState; } public long getSynchronousDurationMillis() { @@ -107,21 +107,21 @@ public boolean equals(Object o) { } AcknowledgeCheckpoint that = (AcknowledgeCheckpoint) o; - return checkpointStateHandles != null ? - checkpointStateHandles.equals(that.checkpointStateHandles) : that.checkpointStateHandles == null; + return subtaskState != null ? + subtaskState.equals(that.subtaskState) : that.subtaskState == null; } @Override public int hashCode() { int result = super.hashCode(); - result = 31 * result + (checkpointStateHandles != null ? checkpointStateHandles.hashCode() : 0); + result = 31 * result + (subtaskState != null ? 
subtaskState.hashCode() : 0); return result; } @Override public String toString() { return String.format("Confirm Task Checkpoint %d for (%s/%s) - state=%s", - getCheckpointId(), getJob(), getTaskExecutionId(), checkpointStateHandles); + getCheckpointId(), getJob(), getTaskExecutionId(), subtaskState); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java index 857b8b304a743..9ac2d4496b0a6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java @@ -130,7 +130,7 @@ public NotifyKvStateRegistered( this.jobId = Preconditions.checkNotNull(jobId, "JobID"); this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); - Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP); + Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE); this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); this.kvStateId = Preconditions.checkNotNull(kvStateId, "KvStateID"); @@ -236,7 +236,7 @@ public NotifyKvStateUnregistered( this.jobId = Preconditions.checkNotNull(jobId, "JobID"); this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); - Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP); + Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE); this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 7ca3b382864e4..e5d9b2bf57a9f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -50,7 +50,7 @@ * @param Type of the key by which state is keyed. */ public abstract class AbstractKeyedStateBackend - implements KeyedStateBackend, SnapshotProvider, Closeable { + implements KeyedStateBackend, Snapshotable, Closeable { /** {@link TypeSerializer} for our key. */ protected final TypeSerializer keySerializer; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java index c683a0236eca1..1b53f1a76e71f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.util.Collection; -import java.util.List; /** * A state backend defines how state is stored and snapshotted during checkpoints. 
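With the message and container changes above in place, the acknowledgement path moves exactly one SubtaskState per subtask. The following is an editorial sketch of a task-side caller against the new constructors, not code from this patch; every value is a placeholder, and whether an empty acknowledgement should carry null is a simplification made here for illustration:

    // Sketch: assemble the five-part SubtaskState and wrap it in the reworked
    // AcknowledgeCheckpoint message.
    static AcknowledgeCheckpoint acknowledge(
            JobID jobId,
            ExecutionAttemptID attemptId,
            CheckpointMetaData metaData,
            ChainedStateHandle<StreamStateHandle> legacyOperatorState,    // checkNotNull'ed, must be non-null
            ChainedStateHandle<OperatorStateHandle> managedOperatorState, // snapshotted by the operator backend
            ChainedStateHandle<OperatorStateHandle> rawOperatorState,     // written to the raw checkpoint stream
            KeyGroupsStateHandle managedKeyedState,                       // may be null
            KeyGroupsStateHandle rawKeyedState,                           // may be null
            long durationMillis) {                                        // ack timestamp - trigger timestamp

        SubtaskState subtaskState = new SubtaskState(
                legacyOperatorState, managedOperatorState, rawOperatorState,
                managedKeyedState, rawKeyedState, durationMillis);

        // hasState() is the new emptiness test across all five components.
        return new AcknowledgeCheckpoint(
                jobId, attemptId, metaData, subtaskState.hasState() ? subtaskState : null);
    }
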
@@ -70,7 +69,7 @@ public abstract AbstractKeyedStateBackend restoreKeyedStateBackend( TypeSerializer keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, - List restoredState, + Collection restoredState, TaskKvStateRegistry kvStateRegistry ) throws Exception; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java new file mode 100644 index 0000000000000..aeb0ce85fa12d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.runtime.state; + +import org.apache.flink.core.fs.FSDataInputStream; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Wrapper around a FSDataInputStream to limit the maximum read offset. + * + * Based on the implementation from org.apache.commons.io.input.BoundedInputStream + */ +public class BoundedInputStream extends InputStream { + private final FSDataInputStream delegate; + private long endOffsetExclusive; + private long position; + private long mark; + + public BoundedInputStream(FSDataInputStream delegate, long endOffsetExclusive) throws IOException { + this.position = delegate.getPos(); + this.mark = -1L; + this.endOffsetExclusive = endOffsetExclusive; + this.delegate = delegate; + } + + public int read() throws IOException { + if (endOffsetExclusive >= 0L && position >= endOffsetExclusive) { + return -1; + } else { + int result = delegate.read(); + ++position; + return result; + } + } + + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + public int read(byte[] b, int off, int len) throws IOException { + if (endOffsetExclusive >= 0L && position >= endOffsetExclusive) { + return -1; + } else { + long maxRead = endOffsetExclusive >= 0L ? Math.min((long) len, endOffsetExclusive - position) : (long) len; + int bytesRead = delegate.read(b, off, (int) maxRead); + if (bytesRead == -1) { + return -1; + } else { + position += (long) bytesRead; + return bytesRead; + } + } + } + + public long skip(long n) throws IOException { + long toSkip = endOffsetExclusive >= 0L ? Math.min(n, endOffsetExclusive - position) : n; + long skippedBytes = delegate.skip(toSkip); + position += skippedBytes; + return skippedBytes; + } + + public int available() throws IOException { + return endOffsetExclusive >= 0L && position >= endOffsetExclusive ? 
0 : delegate.available(); + } + + public String toString() { + return delegate.toString(); + } + + public void close() throws IOException { + delegate.close(); + } + + public synchronized void reset() throws IOException { + delegate.reset(); + position = mark; + } + + public synchronized void mark(int readlimit) { + delegate.mark(readlimit); + mark = position; + } + + public long getEndOffsetExclusive() { + return endOffsetExclusive; + } + + public void setEndOffsetExclusive(long endOffsetExclusive) { + this.endOffsetExclusive = endOffsetExclusive; + } + + public boolean markSupported() { + return delegate.markSupported(); + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java index c6904c08d7af1..a807428f3199d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java @@ -123,4 +123,8 @@ public int hashCode() { public static ChainedStateHandle wrapSingleHandle(T stateHandleToWrap) { return new ChainedStateHandle(Collections.singletonList(stateHandleToWrap)); } + + public static boolean isNullOrEmpty(ChainedStateHandle chainedStateHandle) { + return chainedStateHandle == null || chainedStateHandle.isEmpty(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java deleted file mode 100644 index 9daf9639132e6..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.state; - -import java.io.Serializable; -import java.util.List; - -/** - * Container state handles that contains all state handles from the different state types of a checkpointed state. - * TODO This will be changed in the future if we get rid of chained state and instead connect state directly to individual operators in a chain. 
- */ -public class CheckpointStateHandles implements Serializable { - - private static final long serialVersionUID = 3252351989995L; - - private final ChainedStateHandle nonPartitionedStateHandles; - - private final ChainedStateHandle partitioneableStateHandles; - - private final List keyGroupsStateHandle; - - public CheckpointStateHandles( - ChainedStateHandle nonPartitionedStateHandles, - ChainedStateHandle partitioneableStateHandles, - List keyGroupsStateHandle) { - - this.nonPartitionedStateHandles = nonPartitionedStateHandles; - this.partitioneableStateHandles = partitioneableStateHandles; - this.keyGroupsStateHandle = keyGroupsStateHandle; - } - - public ChainedStateHandle getNonPartitionedStateHandles() { - return nonPartitionedStateHandles; - } - - public ChainedStateHandle getPartitioneableStateHandles() { - return partitioneableStateHandles; - } - - public List getKeyGroupsStateHandle() { - return keyGroupsStateHandle; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof CheckpointStateHandles)) { - return false; - } - - CheckpointStateHandles that = (CheckpointStateHandles) o; - - if (nonPartitionedStateHandles != null ? - !nonPartitionedStateHandles.equals(that.nonPartitionedStateHandles) - : that.nonPartitionedStateHandles != null) { - return false; - } - - if (partitioneableStateHandles != null ? - !partitioneableStateHandles.equals(that.partitioneableStateHandles) - : that.partitioneableStateHandles != null) { - return false; - } - return keyGroupsStateHandle != null ? - keyGroupsStateHandle.equals(that.keyGroupsStateHandle) : that.keyGroupsStateHandle == null; - - } - - @Override - public int hashCode() { - int result = nonPartitionedStateHandles != null ? nonPartitionedStateHandles.hashCode() : 0; - result = 31 * result + (partitioneableStateHandles != null ? partitioneableStateHandles.hashCode() : 0); - result = 31 * result + (keyGroupsStateHandle != null ? keyGroupsStateHandle.hashCode() : 0); - return result; - } - - @Override - public String toString() { - return "CheckpointStateHandles{" + - "nonPartitionedStateHandles=" + nonPartitionedStateHandles + - ", partitioneableStateHandles=" + partitioneableStateHandles + - ", keyGroupsStateHandle=" + keyGroupsStateHandle + - '}'; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java index 26d61927ab194..f6285b058794b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java @@ -25,6 +25,13 @@ import java.util.HashSet; import java.util.Set; +/** + * This class allows to register instances of {@link Closeable}, which are all closed if this registry is closed. + *
<p>
+ * Registering with an already closed registry will throw an exception and close the provided {@link Closeable}. + *
<p>
+ * All methods in this class are thread-safe. + */ public class ClosableRegistry implements Closeable { private final Set registeredCloseables; @@ -35,7 +42,15 @@ public ClosableRegistry() { this.closed = false; } - public boolean registerClosable(Closeable closeable) { + /** + * Registers a {@link Closeable} with the registry. In case the registry is already closed, this method throws an + * {@link IOException} and closes the passed {@link Closeable}. + * + * @param closeable Closeable to register + * @return true if the Closeable was newly added to the registry + * @throws IOException if the registry was already closed + */ + public boolean registerClosable(Closeable closeable) throws IOException { if (null == closeable) { return false; @@ -43,13 +58,20 @@ public boolean registerClosable(Closeable closeable) { synchronized (getSynchronizationLock()) { if (closed) { - throw new IllegalStateException("Cannot register Closable, registry is already closed."); + IOUtils.closeQuietly(closeable); + throw new IOException("Cannot register Closeable, registry is already closed. Closed the passed Closeable."); } return registeredCloseables.add(closeable); } } + /** + * Removes a {@link Closeable} from the registry. + * + * @param closeable instance to remove from the registry. + * @return true if the instance was actually registered and is now removed + */ public boolean unregisterClosable(Closeable closeable) { if (null == closeable) { @@ -78,6 +100,12 @@ public void close() throws IOException { } } + public boolean isClosed() { + synchronized (getSynchronizationLock()) { + return closed; + } + } + private Object getSynchronizationLock() { return registeredCloseables; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java new file mode 100644 index 0000000000000..e34e731939c85 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.KeyedStateStore; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import static java.util.Objects.requireNonNull; + +/** + * Default implementation of KeyedStateStore that currently forwards state registration to a {@link RuntimeContext}. + */ +public class DefaultKeyedStateStore implements KeyedStateStore { + + private final AbstractKeyedStateBackend keyedStateBackend; + private final ExecutionConfig executionConfig; + + public DefaultKeyedStateStore(AbstractKeyedStateBackend keyedStateBackend, ExecutionConfig executionConfig) { + this.keyedStateBackend = Preconditions.checkNotNull(keyedStateBackend); + this.executionConfig = Preconditions.checkNotNull(executionConfig); + } + + @Override + public ValueState getState(ValueStateDescriptor stateProperties) { + requireNonNull(stateProperties, "The state properties must not be null"); + try { + stateProperties.initializeSerializerUnlessSet(executionConfig); + return getPartitionedState(stateProperties); + } catch (Exception e) { + throw new RuntimeException("Error while getting state", e); + } + } + + @Override + public ListState getListState(ListStateDescriptor stateProperties) { + requireNonNull(stateProperties, "The state properties must not be null"); + try { + stateProperties.initializeSerializerUnlessSet(executionConfig); + ListState originalState = getPartitionedState(stateProperties); + return new UserFacingListState<>(originalState); + } catch (Exception e) { + throw new RuntimeException("Error while getting state", e); + } + } + + @Override + public ReducingState getReducingState(ReducingStateDescriptor stateProperties) { + requireNonNull(stateProperties, "The state properties must not be null"); + try { + stateProperties.initializeSerializerUnlessSet(executionConfig); + return getPartitionedState(stateProperties); + } catch (Exception e) { + throw new RuntimeException("Error while getting state", e); + } + } + + /** + * Creates a partitioned state handle, using the state backend configured for this task. + * + * @throws IllegalStateException Thrown, if the key/value state was already initialized. + * @throws Exception Thrown, if the state backend cannot create the key/value state. + */ + public S getPartitionedState(StateDescriptor stateDescriptor) throws Exception { + return getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); + } + + /** + * Creates a partitioned state handle, using the state backend configured for this task. + * + * @throws IllegalStateException Thrown, if the key/value state was already initialized. + * @throws Exception Thrown, if the state backend cannot create the key/value state. 
+ */ + @SuppressWarnings("unchecked") + public S getPartitionedState( + N namespace, TypeSerializer namespaceSerializer, + StateDescriptor stateDescriptor) throws Exception { + + return keyedStateBackend.getPartitionedState( + namespace, + namespaceSerializer, + stateDescriptor); + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index b1ab7e365174c..2f5d3cb13bbbb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -20,7 +20,6 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; @@ -45,6 +44,9 @@ */ public class DefaultOperatorStateBackend implements OperatorStateBackend { + /** The default namespace for state in cases where no state name is provided */ + public static final String DEFAULT_OPERATOR_STATE_NAME = "_default_"; + private final Map> registeredStates; private final Collection restoreSnapshots; private final ClosableRegistry closeStreamOnCancelRegistry; @@ -72,15 +74,12 @@ public DefaultOperatorStateBackend( public DefaultOperatorStateBackend(ClassLoader userClassLoader) { this(userClassLoader, null); } - + @SuppressWarnings("unchecked") @Override - public ListState getSerializableListState(String stateName) throws Exception { - return getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer)); + public ListState getSerializableListState(String stateName) throws Exception { + return (ListState) getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer)); } - - /** - * @see OperatorStateStore - */ + @Override public ListState getOperatorState( ListStateDescriptor stateDescriptor) throws IOException { @@ -102,8 +101,9 @@ public ListState getOperatorState( // Try to restore previous state if state handles to snapshots are provided if (restoreSnapshots != null) { for (OperatorStateHandle stateHandle : restoreSnapshots) { - - long[] offsets = stateHandle.getStateNameToPartitionOffsets().get(name); + // TODO we could be even more GC friendly by removing handles from the collection once the map is empty + // search and remove to be GC friendly + long[] offsets = stateHandle.getStateNameToPartitionOffsets().remove(name); if (offsets != null) { @@ -130,10 +130,7 @@ public ListState getOperatorState( return partitionableListState; } - - /** - * @see SnapshotProvider - */ + @Override public RunnableFuture snapshot( long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { @@ -159,7 +156,7 @@ public RunnableFuture snapshot( writtenStatesMetaData.put(entry.getKey(), partitionOffsets); } - OperatorStateHandle handle = new OperatorStateHandle(out.closeAndGetHandle(), writtenStatesMetaData); + OperatorStateHandle handle = new OperatorStateHandle(writtenStatesMetaData, out.closeAndGetHandle()); return new DoneFuture<>(handle); } finally { @@ -170,48 +167,59 @@ @Override public void dispose() { - + registeredStates.clear(); } static final class PartitionableListState implements ListState { - 
private final List listState; + private final List internalList; private final TypeSerializer partitionStateSerializer; public PartitionableListState(TypeSerializer partitionStateSerializer) { - this.listState = new ArrayList<>(); + this.internalList = new ArrayList<>(); this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer); } @Override public void clear() { - listState.clear(); + internalList.clear(); } @Override public Iterable get() { - return listState; + return internalList; } @Override public void add(S value) { - listState.add(value); + internalList.add(value); } public long[] write(FSDataOutputStream out) throws IOException { - long[] partitionOffsets = new long[listState.size()]; + long[] partitionOffsets = new long[internalList.size()]; DataOutputView dov = new DataOutputViewStreamWrapper(out); - for (int i = 0; i < listState.size(); ++i) { - S element = listState.get(i); + for (int i = 0; i < internalList.size(); ++i) { + S element = internalList.get(i); partitionOffsets[i] = out.getPos(); partitionStateSerializer.serialize(element, dov); } return partitionOffsets; } + + public List getInternalList() { + return internalList; + } + + @Override + public String toString() { + return "PartitionableListState{" + + "listState=" + internalList + + '}'; + } } @Override @@ -223,5 +231,6 @@ public Set getRegisteredStateNames() { public void close() throws IOException { closeStreamOnCancelRegistry.close(); } + } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java new file mode 100644 index 0000000000000..dce57e739997b --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * This interface provides a context in which user functions can initialize by registering to managed state (i.e. state + * that is managed by state backends). + *
+ * <p>
+ * + * Operator state is available to all functions, while keyed state is only available for functions after keyBy. + *
+ * <p>
+ * + * For the purpose of initialization, the context signals if the state is empty or was restored from a previous + * execution. + * + */ +@PublicEvolving +public interface FunctionInitializationContext extends ManagedInitializationContext { +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java new file mode 100644 index 0000000000000..571b881813701 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * This interface provides a context in which user functions that use managed state (i.e. state that is managed by state + * backends) can participate in a snapshot. As snapshots of the backends themselves are taken by the system, this + * interface mainly provides meta information about the checkpoint. + */ +@PublicEvolving +public interface FunctionSnapshotContext extends ManagedSnapshotContext { +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java index 3a9d3d05b318f..32151dbc27f84 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java @@ -27,10 +27,12 @@ * This class defines a range of key-group indexes. Key-groups are the granularity into which the keyspace of a job * is partitioned for keyed state-handling in state backends. The boundaries of the range are inclusive. */ -public class KeyGroupRange implements Iterable, Serializable { +public class KeyGroupRange implements KeyGroupsList, Serializable { + + private static final long serialVersionUID = 4869121477592070607L; /** The empty key-group */ - public static final KeyGroupRange EMPTY_KEY_GROUP = new KeyGroupRange(); + public static final KeyGroupRange EMPTY_KEY_GROUP_RANGE = new KeyGroupRange(); private final int startKeyGroup; private final int endKeyGroup; @@ -64,6 +66,7 @@ public KeyGroupRange(int startKeyGroup, int endKeyGroup) { * @param keyGroup Key-group to check for inclusion. * @return True, only if the key-group is in the range. 
*/ + @Override public boolean contains(int keyGroup) { return keyGroup >= startKeyGroup && keyGroup <= endKeyGroup; } @@ -77,13 +80,14 @@ public boolean contains(int keyGroup) { public KeyGroupRange getIntersection(KeyGroupRange other) { int start = Math.max(startKeyGroup, other.startKeyGroup); int end = Math.min(endKeyGroup, other.endKeyGroup); - return start <= end ? new KeyGroupRange(start, end) : EMPTY_KEY_GROUP; + return start <= end ? new KeyGroupRange(start, end) : EMPTY_KEY_GROUP_RANGE; } /** * * @return The number of key-groups in the range */ + @Override public int getNumberOfKeyGroups() { return 1 + endKeyGroup - startKeyGroup; } @@ -104,6 +108,14 @@ public int getEndKeyGroup() { return endKeyGroup; } + @Override + public int getKeyGroupId(int idx) { + if (idx < 0 || idx >= getNumberOfKeyGroups()) { + throw new IndexOutOfBoundsException("Key group index out of bounds: " + idx); + } + return startKeyGroup + idx; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -172,7 +184,6 @@ public void remove() { * @return the key-group from start to end or an empty key-group range. */ public static KeyGroupRange of(int startKeyGroup, int endKeyGroup) { - return startKeyGroup <= endKeyGroup ? new KeyGroupRange(startKeyGroup, endKeyGroup) : EMPTY_KEY_GROUP; + return startKeyGroup <= endKeyGroup ? new KeyGroupRange(startKeyGroup, endKeyGroup) : EMPTY_KEY_GROUP_RANGE; } - } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java index 8e7207e004fb4..d4252784d66d9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java @@ -137,7 +137,11 @@ public Iterator> iterator() { } private int computeKeyGroupIndex(int keyGroup) { - return keyGroup - keyGroupRange.getStartKeyGroup(); + int idx = keyGroup - keyGroupRange.getStartKeyGroup(); + if (idx < 0 || idx >= offsets.length) { + throw new IllegalArgumentException("Key group " + keyGroup + " is not in " + keyGroupRange + "."); + } + return idx; } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java new file mode 100644 index 0000000000000..2a91f0f29df9f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +import java.io.IOException; +import java.io.InputStream; + +/** + * This class provides access to an input stream that contains state data for one key group and the key group id. + */ +@PublicEvolving +public class KeyGroupStatePartitionStreamProvider extends StatePartitionStreamProvider { + + /** Key group that corresponds to the data in the provided stream */ + private final int keyGroupId; + + public KeyGroupStatePartitionStreamProvider(IOException creationException, int keyGroupId) { + super(creationException); + this.keyGroupId = keyGroupId; + } + + public KeyGroupStatePartitionStreamProvider(InputStream stream, int keyGroupId) { + super(stream); + this.keyGroupId = keyGroupId; + } + + /** + * Returns the key group that corresponds to the data in the provided stream. + */ + public int getKeyGroupId() { + return keyGroupId; + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java new file mode 100644 index 0000000000000..928ebf3f94400 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * This interface offers ordered random read access to multiple key group ids. + */ +public interface KeyGroupsList extends Iterable { + + /** + * Returns the number of key group ids in the list. + */ + int getNumberOfKeyGroups(); + + /** + * Returns the id of the keygroup at the given index, where index in interval [0, {@link #getNumberOfKeyGroups()}[. + * + * @param idx the index into the list + * @return key group id at the given index + */ + int getKeyGroupId(int idx); + + /** + * Returns true, if the given key group id is contained in the list, otherwise false. + */ + boolean contains(int keyGroupId); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index 7293a8486a0fb..03f584ee47bb6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -54,9 +54,9 @@ public interface KeyedStateBackend { int getNumberOfKeyGroups(); /** - * Returns the key group range for this backend. + * Returns the key groups for this backend. */ - KeyGroupRange getKeyGroupRange(); + KeyGroupsList getKeyGroupRange(); /** * {@link TypeSerializer} for the state backend key type. 
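The KeyedStateBackend change above swaps the concrete KeyGroupRange for the new KeyGroupsList interface. A minimal sketch of ordered access through that interface, using only classes added in this patch (the concrete range is chosen for illustration):

    KeyGroupsList groups = KeyGroupRange.of(0, 3); // KeyGroupRange implements KeyGroupsList

    for (int i = 0; i < groups.getNumberOfKeyGroups(); ++i) {
        int keyGroupId = groups.getKeyGroupId(i); // ordered random access, valid for i in [0, getNumberOfKeyGroups())
        assert groups.contains(keyGroupId);       // ids obtained by index are always contained in the list
    }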
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java new file mode 100644 index 0000000000000..21215746b7b71 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Arrays; + +/** + * Checkpoint output stream that allows to write raw keyed state in a partitioned way, split into key-groups. + */ +@PublicEvolving +public final class KeyedStateCheckpointOutputStream extends NonClosingCheckpointOutputStream { + + public static final long NO_OFFSET_SET = -1L; + public static final int NO_CURRENT_KEY_GROUP = -1; + + private int currentKeyGroup; + private final KeyGroupRangeOffsets keyGroupRangeOffsets; + + public KeyedStateCheckpointOutputStream( + CheckpointStreamFactory.CheckpointStateOutputStream delegate, KeyGroupRange keyGroupRange) { + + super(delegate); + Preconditions.checkNotNull(keyGroupRange); + Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE); + + this.currentKeyGroup = NO_CURRENT_KEY_GROUP; + long[] emptyOffsets = new long[keyGroupRange.getNumberOfKeyGroups()]; + // mark offsets as currently not set + Arrays.fill(emptyOffsets, NO_OFFSET_SET); + this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, emptyOffsets); + } + + @Override + public void close() throws IOException { + // users should not be able to actually close the stream, it is closed by the system. + // TODO if we want to support async writes, this call could trigger a callback to the snapshot context that a handle is available. + } + + /** + * Returns a list of all key-groups which can be written to this stream. + */ + public KeyGroupsList getKeyGroupList() { + return keyGroupRangeOffsets.getKeyGroupRange(); + } + + /** + * User code can call this method to signal that it begins to write a new key group with the given key group id. + * This id must be within the {@link KeyGroupsList} provided by the stream. Each key-group can only be started once + * and is considered final/immutable as soon as this method is called again. 
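+ *
+ * <p>
+ * For illustration, a minimal sketch of the intended write pattern, assuming a StateSnapshotContext named
+ * context is in scope (the payload is hypothetical):
+ * <pre>{@code
+ * KeyedStateCheckpointOutputStream out = context.getRawKeyedOperatorStateOutput();
+ * DataOutputView dov = new DataOutputViewStreamWrapper(out);
+ * KeyGroupsList groups = out.getKeyGroupList();
+ * for (int i = 0; i < groups.getNumberOfKeyGroups(); ++i) {
+ *     int keyGroupId = groups.getKeyGroupId(i);
+ *     out.startNewKeyGroup(keyGroupId);        // records the current offset for this key-group
+ *     dov.writeUTF("data-for-" + keyGroupId);  // hypothetical payload
+ * }
+ * }</pre>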
+ */ + public void startNewKeyGroup(int keyGroupId) throws IOException { + if (isKeyGroupAlreadyStarted(keyGroupId)) { + throw new IOException("Key group " + keyGroupId + " already registered!"); + } + keyGroupRangeOffsets.setKeyGroupOffset(keyGroupId, delegate.getPos()); + currentKeyGroup = keyGroupId; + } + + /** + * Returns true, if the key group with the given id was already started. The key group might not yet be finished, + * if its id is equal to the return value of {@link #getCurrentKeyGroup()}. + */ + public boolean isKeyGroupAlreadyStarted(int keyGroupId) { + return NO_OFFSET_SET != keyGroupRangeOffsets.getKeyGroupOffset(keyGroupId); + } + + /** + * Returns true if the key group is already completely written and immutable. It was started and since then another + * key group has been started. + */ + public boolean isKeyGroupAlreadyFinished(int keyGroupId) { + return isKeyGroupAlreadyStarted(keyGroupId) && keyGroupId != getCurrentKeyGroup(); + } + + /** + * Returns the key group that is currently being written. The key group was started but not yet finished, i.e. data + * can still be added. If no key group was started, this returns {@link #NO_CURRENT_KEY_GROUP}. + */ + public int getCurrentKeyGroup() { + return currentKeyGroup; + } + + @Override + KeyGroupsStateHandle closeAndGetHandle() throws IOException { + StreamStateHandle streamStateHandle = delegate.closeAndGetHandle(); + return streamStateHandle != null ? new KeyGroupsStateHandle(keyGroupRangeOffsets, streamStateHandle) : null; + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java new file mode 100644 index 0000000000000..cf7106393b586 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.KeyedStateStore; +import org.apache.flink.api.common.state.OperatorStateStore; + +/** + * This interface provides a context in which operators can initialize by registering to managed state (i.e. state that + * is managed by state backends). + *
+ * <p>
+ * + * Operator state is available to all operators, while keyed state is only available for operators after keyBy. + *
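+ *
+ * <p>
+ * A minimal sketch of registering managed state through this context, assuming a ManagedInitializationContext
+ * named context is in scope (the state name and the restore handling are hypothetical):
+ * <pre>{@code
+ * ListState<Serializable> counts = context.getManagedOperatorStateStore().getSerializableListState("counts");
+ * if (context.isRestored()) {
+ *     for (Serializable entry : counts.get()) {
+ *         // re-apply entries restored from the last snapshot
+ *     }
+ * }
+ * }</pre>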
+ * <p>
+ * + * For the purpose of initialization, the context signals if the state is empty (new operator) or was restored from + * a previous execution of this operator. + * + */ +public interface ManagedInitializationContext { + + /** + * Returns true, if some managed state was restored from the snapshot of a previous execution. + */ + boolean isRestored(); + + /** + * Returns an interface that allows for registering operator state with the backend. + */ + OperatorStateStore getManagedOperatorStateStore(); + + /** + * Returns an interface that allows for registering keyed state with the backend. + */ + KeyedStateStore getManagedKeyedStateStore(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java new file mode 100644 index 0000000000000..14156a6408c8c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * This interface provides a context in which operators that use managed state (i.e. state that is managed by state + * backends) can perform a snapshot. As snapshots of the backends themselves are taken by the system, this interface + * mainly provides meta information about the checkpoint. + */ +@PublicEvolving +public interface ManagedSnapshotContext { + + /** + * Returns the Id of the checkpoint for which the snapshot is taken. + */ + long getCheckpointId(); + + /** + * Returns the timestamp of the checkpoint for which the snapshot is taken. + */ + long getCheckpointTimestamp(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java new file mode 100644 index 0000000000000..f7f4bdbcdd8fe --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Abstract class to implement custom checkpoint output streams which should not be closable for user code. + * + * @param type of the returned state handle. + */ +public abstract class NonClosingCheckpointOutputStream extends OutputStream { + + protected final CheckpointStreamFactory.CheckpointStateOutputStream delegate; + + public NonClosingCheckpointOutputStream( + CheckpointStreamFactory.CheckpointStateOutputStream delegate) { + this.delegate = Preconditions.checkNotNull(delegate); + } + + @Override + public void flush() throws IOException { + delegate.flush(); + } + + @Override + public void write(int b) throws IOException { + delegate.write(b); + } + + @Override + public void write(byte[] b) throws IOException { + delegate.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + delegate.write(b, off, len); + } + + @Override + public void close() throws IOException { + // users should not be able to actually close the stream, it is closed by the system. + // TODO if we want to support async writes, this call could trigger a callback to the snapshot context that a handle is available. + } + + + /** + * This method should not be public so as to not expose internals to user code. + */ + CheckpointStreamFactory.CheckpointStateOutputStream getDelegate() { + return delegate; + } + + /** + * This method should not be public so as to not expose internals to user code. Closes the underlying stream and + * returns a state handle. + */ + abstract T closeAndGetHandle() throws IOException; + +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java index 83e6369f79ac5..aee5226c08397 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java @@ -24,10 +24,10 @@ /** * Interface that combines both, the user facing {@link OperatorStateStore} interface and the system interface - * {@link SnapshotProvider} + * {@link Snapshotable} * */ -public interface OperatorStateBackend extends OperatorStateStore, SnapshotProvider, Closeable { +public interface OperatorStateBackend extends OperatorStateStore, Snapshotable, Closeable { /** * Disposes the backend and releases all resources. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java new file mode 100644 index 0000000000000..eaa9fd9014886 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.runtime.util.LongArrayList; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Checkpoint output stream that allows to write raw operator state in a partitioned way. + */ +@PublicEvolving +public final class OperatorStateCheckpointOutputStream + extends NonClosingCheckpointOutputStream { + + private LongArrayList partitionOffsets; + private final long initialPosition; + + public OperatorStateCheckpointOutputStream( + CheckpointStreamFactory.CheckpointStateOutputStream delegate) throws IOException { + + super(delegate); + this.partitionOffsets = new LongArrayList(16); + this.initialPosition = delegate.getPos(); + } + + /** + * User code can call this method to signal that it begins to write a new partition of operator state. + * Each previously written partition is considered final/immutable as soon as this method is called again. + */ + public void startNewPartition() throws IOException { + partitionOffsets.add(delegate.getPos()); + } + + /** + * This method should not be public so as to not expose internals to user code. 
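+ *
+ * <p>
+ * User code only calls {@link #startNewPartition()} and the inherited write methods. A minimal sketch, assuming a
+ * StateSnapshotContext named context is in scope (the payload is hypothetical):
+ * <pre>{@code
+ * OperatorStateCheckpointOutputStream out = context.getRawOperatorStateOutput();
+ * DataOutputView dov = new DataOutputViewStreamWrapper(out);
+ * out.startNewPartition(); // the previous partition, if any, becomes immutable
+ * dov.writeInt(42);        // hypothetical payload for this partition
+ * }</pre>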
+ */ + @Override + OperatorStateHandle closeAndGetHandle() throws IOException { + StreamStateHandle streamStateHandle = delegate.closeAndGetHandle(); + + if (null == streamStateHandle) { + return null; + } + + if (partitionOffsets.isEmpty() && delegate.getPos() > initialPosition) { + startNewPartition(); + } + + Map offsetsMap = new HashMap<>(1); + offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, partitionOffsets.toArray()); + + return new OperatorStateHandle(offsetsMap, streamStateHandle); + } + + public int getNumberOfPartitions() { + return partitionOffsets.size(); + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java index 3e2d713af309f..1ad41ea6735a4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java @@ -38,8 +38,8 @@ public class OperatorStateHandle implements StreamStateHandle { private final StreamStateHandle delegateStateHandle; public OperatorStateHandle( - StreamStateHandle delegateStateHandle, - Map stateNameToPartitionOffsets) { + Map stateNameToPartitionOffsets, + StreamStateHandle delegateStateHandle) { this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle); this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java deleted file mode 100644 index 065f9c2f61981..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.runtime.state; - -import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.util.Preconditions; - -import java.io.IOException; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -public class PartitionableCheckpointStateOutputStream extends FSDataOutputStream { - - private final Map stateNameToPartitionOffsets; - private final CheckpointStreamFactory.CheckpointStateOutputStream delegate; - - public PartitionableCheckpointStateOutputStream(CheckpointStreamFactory.CheckpointStateOutputStream delegate) { - this.delegate = Preconditions.checkNotNull(delegate); - this.stateNameToPartitionOffsets = new HashMap<>(); - } - - @Override - public long getPos() throws IOException { - return delegate.getPos(); - } - - @Override - public void flush() throws IOException { - delegate.flush(); - } - - @Override - public void sync() throws IOException { - delegate.sync(); - } - - @Override - public void write(int b) throws IOException { - delegate.write(b); - } - - @Override - public void write(byte[] b) throws IOException { - delegate.write(b); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - delegate.write(b, off, len); - } - - @Override - public void close() throws IOException { - delegate.close(); - } - - public OperatorStateHandle closeAndGetHandle() throws IOException { - StreamStateHandle streamStateHandle = delegate.closeAndGetHandle(); - return new OperatorStateHandle(streamStateHandle, stateNameToPartitionOffsets); - } - - public void startNewPartition(String stateName) throws IOException { - long[] offs = stateNameToPartitionOffsets.get(stateName); - if (offs == null) { - offs = new long[1]; - } else { - //TODO maybe we can use some primitive array list here instead of an array to avoid resize on each call. - offs = Arrays.copyOf(offs, offs.length + 1); - } - - offs[offs.length - 1] = getPos(); - stateNameToPartitionOffsets.put(stateName, offs); - } - - public static PartitionableCheckpointStateOutputStream wrap( - CheckpointStreamFactory.CheckpointStateOutputStream stream) { - return new PartitionableCheckpointStateOutputStream(stream); - } -} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java similarity index 96% rename from flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java index c47fedd2ce79d..2aa282d064988 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java @@ -25,7 +25,7 @@ * * @param Generic type of the state object that is created as handle to snapshots. 
*/ -public interface SnapshotProvider { +public interface Snapshotable { /** * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java new file mode 100644 index 0000000000000..150aa2fb974a9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * This interface provides a context in which operators can initialize by registering to managed state (i.e. state that + * is managed by state backends) or iterating over streams of state partitions written as raw state in a previous + * snapshot. + *
+ * <p>
+ * Similar to the managed state from {@link ManagedInitializationContext} and in general, raw operator state is + * available to all operators, while raw keyed state is only available for operators after keyBy. + *
+ * <p>
+ * For the purpose of initialization, the context signals if all state is empty (new operator) or if any state was + * restored from a previous execution of this operator. + * + */ +@PublicEvolving +public interface StateInitializationContext extends FunctionInitializationContext { + + /** + * Returns an iterable to obtain input streams for previously stored operator state partitions that are assigned to + * this operator. + */ + Iterable getRawOperatorStateInputs(); + + /** + * Returns an iterable to obtain input streams for previously stored keyed state partitions that are assigned to + * this operator. + */ + Iterable getRawKeyedStateInputs(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java new file mode 100644 index 0000000000000..8fbde051dd6be --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.commons.io.IOUtils; +import org.apache.flink.api.common.state.KeyedStateStore; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; + +/** + * Default implementation of {@link StateInitializationContext}. 
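+ *
+ * <p>
+ * A minimal sketch of how an operator can consume the raw streams offered through this context, assuming a
+ * StateInitializationContext named context is in scope (the payload decoding is hypothetical; IOException
+ * handling is omitted):
+ * <pre>{@code
+ * for (StatePartitionStreamProvider provider : context.getRawOperatorStateInputs()) {
+ *     DataInputView in = new DataInputViewStreamWrapper(provider.getStream());
+ *     int value = in.readInt(); // hypothetical payload written during a previous snapshot
+ * }
+ * for (KeyGroupStatePartitionStreamProvider provider : context.getRawKeyedStateInputs()) {
+ *     int keyGroup = provider.getKeyGroupId(); // key-group that the stream data belongs to
+ * }
+ * }</pre>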
+ */ +public class StateInitializationContextImpl implements StateInitializationContext { + + /** Closable registry to participate in the operator's cancel/close methods */ + private final ClosableRegistry closableRegistry; + + /** Signal whether any state to restore was found */ + private final boolean restored; + + private final OperatorStateStore operatorStateStore; + private final Collection operatorStateHandles; + + private final KeyedStateStore keyedStateStore; + private final Collection keyGroupsStateHandles; + + private final Iterable keyedStateIterable; + + public StateInitializationContextImpl( + boolean restored, + OperatorStateStore operatorStateStore, + KeyedStateStore keyedStateStore, + Collection keyGroupsStateHandles, + Collection operatorStateHandles, + ClosableRegistry closableRegistry) { + + this.restored = restored; + this.closableRegistry = Preconditions.checkNotNull(closableRegistry); + this.operatorStateStore = operatorStateStore; + this.keyedStateStore = keyedStateStore; + this.operatorStateHandles = operatorStateHandles; + this.keyGroupsStateHandles = keyGroupsStateHandles; + + this.keyedStateIterable = keyGroupsStateHandles == null ? + null + : new Iterable() { + @Override + public Iterator iterator() { + return new KeyGroupStreamIterator(getKeyGroupsStateHandles().iterator(), getClosableRegistry()); + } + }; + } + + @Override + public boolean isRestored() { + return restored; + } + + public Collection getOperatorStateHandles() { + return operatorStateHandles; + } + + public Collection getKeyGroupsStateHandles() { + return keyGroupsStateHandles; + } + + public ClosableRegistry getClosableRegistry() { + return closableRegistry; + } + + @Override + public Iterable getRawOperatorStateInputs() { + if (null != operatorStateHandles) { + return new Iterable() { + @Override + public Iterator iterator() { + return new OperatorStateStreamIterator( + DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, + getOperatorStateHandles().iterator(), getClosableRegistry()); + } + }; + } else { + return Collections.emptyList(); + } + } + + @Override + public Iterable getRawKeyedStateInputs() { + if(null == keyedStateStore) { + throw new IllegalStateException("Attempt to access keyed state from non-keyed operator."); + } + + if (null != keyGroupsStateHandles) { + return keyedStateIterable; + } else { + return Collections.emptyList(); + } + } + + @Override + public OperatorStateStore getManagedOperatorStateStore() { + return operatorStateStore; + } + + @Override + public KeyedStateStore getManagedKeyedStateStore() { + return keyedStateStore; + } + + public void close() { + IOUtils.closeQuietly(closableRegistry); + } + + private static class KeyGroupStreamIterator implements Iterator { + + private final Iterator stateHandleIterator; + private final ClosableRegistry closableRegistry; + + private KeyGroupsStateHandle currentStateHandle; + private FSDataInputStream currentStream; + private Iterator> currentOffsetsIterator; + + public KeyGroupStreamIterator( + Iterator stateHandleIterator, ClosableRegistry closableRegistry) { + + this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator); + this.closableRegistry = Preconditions.checkNotNull(closableRegistry); + } + + @Override + public boolean hasNext() { + if (null != currentStateHandle && currentOffsetsIterator.hasNext()) { + return true; + } else { + while (stateHandleIterator.hasNext()) { + currentStateHandle = stateHandleIterator.next(); + if (currentStateHandle.getNumberOfKeyGroups() > 0) { + currentOffsetsIterator 
= currentStateHandle.getGroupRangeOffsets().iterator(); + closableRegistry.unregisterClosable(currentStream); + IOUtils.closeQuietly(currentStream); + currentStream = null; + return true; + } + } + return false; + } + } + + private void openStream() throws IOException { + FSDataInputStream stream = currentStateHandle.openInputStream(); + closableRegistry.registerClosable(stream); + currentStream = stream; + } + + @Override + public KeyGroupStatePartitionStreamProvider next() { + Tuple2 keyGroupOffset = currentOffsetsIterator.next(); + try { + if (null == currentStream) { + openStream(); + } + currentStream.seek(keyGroupOffset.f1); + return new KeyGroupStatePartitionStreamProvider(currentStream, keyGroupOffset.f0); + } catch (IOException ioex) { + return new KeyGroupStatePartitionStreamProvider(ioex, keyGroupOffset.f0); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Read only Iterator"); + } + } + + private static class OperatorStateStreamIterator implements Iterator { + + private final String stateName; //TODO since we only support a single named state in raw, this could be dropped + + private final Iterator stateHandleIterator; + private final ClosableRegistry closableRegistry; + + private OperatorStateHandle currentStateHandle; + private FSDataInputStream currentStream; + private long[] offsets; + private int offPos; + + public OperatorStateStreamIterator( + String stateName, + Iterator stateHandleIterator, + ClosableRegistry closableRegistry) { + + this.stateName = Preconditions.checkNotNull(stateName); + this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator); + this.closableRegistry = Preconditions.checkNotNull(closableRegistry); + } + + @Override + public boolean hasNext() { + if (null != currentStateHandle && offPos < offsets.length) { + return true; + } else { + while (stateHandleIterator.hasNext()) { + currentStateHandle = stateHandleIterator.next(); + long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName); + if (null != offsets && offsets.length > 0) { + + this.offsets = offsets; + this.offPos = 0; + + closableRegistry.unregisterClosable(currentStream); + IOUtils.closeQuietly(currentStream); + currentStream = null; + + return true; + } + } + return false; + } + } + + private void openStream() throws IOException { + FSDataInputStream stream = currentStateHandle.openInputStream(); + closableRegistry.registerClosable(stream); + currentStream = stream; + } + + @Override + public StatePartitionStreamProvider next() { + long offset = offsets[offPos++]; + try { + if (null == currentStream) { + openStream(); + } + currentStream.seek(offset); + + return new StatePartitionStreamProvider(currentStream); + } catch (IOException ioex) { + return new StatePartitionStreamProvider(ioex); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Read only Iterator"); + } + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java new file mode 100644 index 0000000000000..8b07da8e1c2bc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.runtime.util.NonClosingStreamDecorator; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.InputStream; + +/** + * This class provides access to input streams that contain data of one state partition of a partitionable state. + * + * TODO use a bounded stream that fails fast if the limit is exceeded on corrupted reads. + */ +@PublicEvolving +public class StatePartitionStreamProvider { + + /** A ready-made stream that contains data for one state partition */ + private final InputStream stream; + + /** Holds a potential exception that happened when actually trying to create the stream */ + private final IOException creationException; + + public StatePartitionStreamProvider(IOException creationException) { + this.creationException = Preconditions.checkNotNull(creationException); + this.stream = null; + } + + public StatePartitionStreamProvider(InputStream stream) { + this.stream = new NonClosingStreamDecorator(Preconditions.checkNotNull(stream)); + this.creationException = null; + } + + + /** + * Returns a stream with the data of one state partition. + */ + public InputStream getStream() throws IOException { + if (creationException != null) { + throw new IOException(creationException); + } + return stream; + } +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java new file mode 100644 index 0000000000000..4dbbeaf6342cd --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * This interface provides a context in which operators that use managed (i.e. state that is managed by state + * backends) or raw (i.e. 
the operator can write its state streams) state can perform a snapshot. + */ +@PublicEvolving +public interface StateSnapshotContext extends FunctionSnapshotContext { + + /** + * Returns an output stream for keyed state + */ + KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() throws Exception; + + /** + * Returns an output stream for operator state + */ + OperatorStateCheckpointOutputStream getRawOperatorStateOutput() throws Exception; + +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java new file mode 100644 index 0000000000000..d632529de49f7 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.concurrent.RunnableFuture; + +/** + * This class is a default implementation for StateSnapshotContext. + */ +public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext { + + private final long checkpointId; + private final long checkpointTimestamp; + + /** Factory for the checkpointing stream */ + private final CheckpointStreamFactory streamFactory; + + /** Key group range for the operator that created this context. Only for keyed operators */ + private final KeyGroupRange keyGroupRange; + + /** + * Registry for opened streams to participate in the lifecycle of the stream task. Hence, this registry should be + * obtained from and managed by the stream task. 
+ */ + private final ClosableRegistry closableRegistry; + + private KeyedStateCheckpointOutputStream keyedStateCheckpointOutputStream; + private OperatorStateCheckpointOutputStream operatorStateCheckpointOutputStream; + + @VisibleForTesting + public StateSnapshotContextSynchronousImpl(long checkpointId, long checkpointTimestamp) { + this.checkpointId = checkpointId; + this.checkpointTimestamp = checkpointTimestamp; + this.streamFactory = null; + this.keyGroupRange = KeyGroupRange.EMPTY_KEY_GROUP_RANGE; + this.closableRegistry = null; + } + + + public StateSnapshotContextSynchronousImpl( + long checkpointId, + long checkpointTimestamp, + CheckpointStreamFactory streamFactory, + KeyGroupRange keyGroupRange, + ClosableRegistry closableRegistry) { + + this.checkpointId = checkpointId; + this.checkpointTimestamp = checkpointTimestamp; + this.streamFactory = Preconditions.checkNotNull(streamFactory); + this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); + this.closableRegistry = Preconditions.checkNotNull(closableRegistry); + } + + @Override + public long getCheckpointId() { + return checkpointId; + } + + @Override + public long getCheckpointTimestamp() { + return checkpointTimestamp; + } + + private CheckpointStreamFactory.CheckpointStateOutputStream openAndRegisterNewStream() throws Exception { + CheckpointStreamFactory.CheckpointStateOutputStream cout = + streamFactory.createCheckpointStateOutputStream(checkpointId, checkpointTimestamp); + + closableRegistry.registerClosable(cout); + return cout; + } + + @Override + public KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() throws Exception { + if (null == keyedStateCheckpointOutputStream) { + Preconditions.checkState(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE, "Not a keyed operator"); + keyedStateCheckpointOutputStream = new KeyedStateCheckpointOutputStream(openAndRegisterNewStream(), keyGroupRange); + } + return keyedStateCheckpointOutputStream; + } + + @Override + public OperatorStateCheckpointOutputStream getRawOperatorStateOutput() throws Exception { + if (null == operatorStateCheckpointOutputStream) { + operatorStateCheckpointOutputStream = new OperatorStateCheckpointOutputStream(openAndRegisterNewStream()); + } + return operatorStateCheckpointOutputStream; + } + + public RunnableFuture getKeyedStateStreamFuture() throws IOException { + return closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream); + } + + public RunnableFuture getOperatorStateStreamFuture() throws IOException { + return closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream); + } + + private RunnableFuture closeAndUnregisterStreamToObtainStateHandle( + NonClosingCheckpointOutputStream stream) throws IOException { + if (null == stream) { + return null; + } + + closableRegistry.unregisterClosable(stream.getDelegate()); + + // for now we only support synchronous writing + return new DoneFuture<>(stream.closeAndGetHandle()); + } + +} \ No newline at end of file diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java new file mode 100644 index 0000000000000..ecd63995185bb --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.util.CollectionUtil; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * This class encapsulates all state handles for a task. + */ +public class TaskStateHandles implements Serializable { + + public static final TaskStateHandles EMPTY = new TaskStateHandles(); + + private static final long serialVersionUID = 267686583583579359L; + + /** State handle with the (non-partitionable) legacy operator state*/ + @Deprecated + private final ChainedStateHandle legacyOperatorState; + + /** Collection of handles which represent the managed keyed state of the head operator */ + private final Collection managedKeyedState; + + /** Collection of handles which represent the raw/streamed keyed state of the head operator */ + private final Collection rawKeyedState; + + /** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */ + private final List> managedOperatorState; + + /** Outer list represents the operator chain, each collection holds handles for raw/streamed state of a single operator */ + private final List> rawOperatorState; + + public TaskStateHandles() { + this(null, null, null, null, null); + } + + public TaskStateHandles(SubtaskState checkpointStateHandles) { + this(checkpointStateHandles.getLegacyOperatorState(), + transform(checkpointStateHandles.getManagedOperatorState()), + transform(checkpointStateHandles.getRawOperatorState()), + transform(checkpointStateHandles.getManagedKeyedState()), + transform(checkpointStateHandles.getRawKeyedState())); + } + + public TaskStateHandles( + ChainedStateHandle legacyOperatorState, + List> managedOperatorState, + List> rawOperatorState, + Collection managedKeyedState, + Collection rawKeyedState) { + + this.legacyOperatorState = legacyOperatorState; + this.managedKeyedState = managedKeyedState; + this.rawKeyedState = rawKeyedState; + this.managedOperatorState = managedOperatorState; + this.rawOperatorState = rawOperatorState; + } + + @Deprecated + public ChainedStateHandle getLegacyOperatorState() { + return legacyOperatorState; + } + + public Collection getManagedKeyedState() { + return managedKeyedState; + } + + public Collection getRawKeyedState() { + return rawKeyedState; + } + + public List> getRawOperatorState() { + return rawOperatorState; + } + + public List> getManagedOperatorState() { + return managedOperatorState; + } + + public boolean hasState() { + return !ChainedStateHandle.isNullOrEmpty(legacyOperatorState) + || !CollectionUtil.isNullOrEmpty(managedKeyedState) + || !CollectionUtil.isNullOrEmpty(rawKeyedState) + || !CollectionUtil.isNullOrEmpty(rawOperatorState) + 
|| !CollectionUtil.isNullOrEmpty(managedOperatorState); + } + + private static List> transform(ChainedStateHandle in) { + if (null == in) { + return Collections.emptyList(); + } + List> out = new ArrayList<>(in.getLength()); + for (int i = 0; i < in.getLength(); ++i) { + OperatorStateHandle osh = in.get(i); + out.add(osh != null ? Collections.singletonList(osh) : null); + } + return out; + } + + private static List transform(KeyGroupsStateHandle in) { + return in == null ? Collections.emptyList() : Collections.singletonList(in); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TaskStateHandles that = (TaskStateHandles) o; + + if (legacyOperatorState != null ? + !legacyOperatorState.equals(that.legacyOperatorState) + : that.legacyOperatorState != null) { + return false; + } + if (managedKeyedState != null ? + !managedKeyedState.equals(that.managedKeyedState) + : that.managedKeyedState != null) { + return false; + } + if (rawKeyedState != null ? + !rawKeyedState.equals(that.rawKeyedState) + : that.rawKeyedState != null) { + return false; + } + + if (rawOperatorState != null ? + !rawOperatorState.equals(that.rawOperatorState) + : that.rawOperatorState != null) { + return false; + } + return managedOperatorState != null ? + managedOperatorState.equals(that.managedOperatorState) + : that.managedOperatorState == null; + } + + @Override + public int hashCode() { + int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0; + result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0); + result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0); + result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0); + result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0); + return result; + } +} \ No newline at end of file diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java similarity index 97% rename from flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java index a02a2043314aa..71026c60e9ea9 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.streaming.api.operators; +package org.apache.flink.runtime.state; import org.apache.flink.api.common.state.ListState; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index e027632d0fa11..4e15cd51048c9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -36,7 +36,7 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; -import java.util.List; +import java.util.Collection; /** * The file state backend is a state backend that stores the state of streaming jobs in a file system. 
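Stepping back to TaskStateHandles as a whole: a minimal sketch of grouping and probing a subtask's handles on restore (the SubtaskState instance is assumed to come from a completed checkpoint acknowledgement):

    TaskStateHandles handles = new TaskStateHandles(subtaskState); // bundles legacy, managed and raw handles
    if (handles.hasState()) {
        Collection<KeyGroupsStateHandle> managedKeyed = handles.getManagedKeyedState();
        // hand the managed keyed handles to the keyed state backend for restore
    }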
@@ -199,7 +199,7 @@ public AbstractKeyedStateBackend restoreKeyedStateBackend( TypeSerializer keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, - List restoredState, + Collection restoredState, TaskKvStateRegistry kvStateRegistry) throws Exception { return new HeapKeyedStateBackend<>( kvStateRegistry, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index b283494993633..56be46fbd5494 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -50,8 +50,8 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.RunnableFuture; @@ -94,7 +94,7 @@ public HeapKeyedStateBackend( ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange, - List restoredState) throws Exception { + Collection restoredState) throws Exception { super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); LOG.info("Initializing heap keyed state backend from snapshot."); @@ -248,7 +248,7 @@ private void writeStateTableForKeyGroup( } @SuppressWarnings({"unchecked"}) - private void restorePartitionedState(List state) throws Exception { + private void restorePartitionedState(Collection state) throws Exception { int numRegisteredKvStates = 0; Map kvStatesById = new HashMap<>(); @@ -259,13 +259,10 @@ private void restorePartitionedState(List state) throws Ex continue; } - FSDataInputStream fsDataInputStream = null; + FSDataInputStream fsDataInputStream = keyGroupsHandle.openInputStream(); + cancelStreamRegistry.registerClosable(fsDataInputStream); try { - - fsDataInputStream = keyGroupsHandle.openInputStream(); - cancelStreamRegistry.registerClosable(fsDataInputStream); - DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream); int numKvStates = inView.readShort(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java index 028f8c83c399b..30de638f80f86 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java @@ -127,6 +127,10 @@ public long getPos() throws IOException { return os.getPosition(); } + public boolean isClosed() { + return closed; + } + /** * Closes the stream and returns the byte array containing the stream's data. * @return The byte array containing the stream's data. 
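The HeapKeyedStateBackend hunk above deliberately opens the stream and registers it with cancelStreamRegistry before entering the try block, so that a concurrent cancel can close the stream and unblock a restore thread stuck in a blocking read. A minimal sketch of that pattern follows; the registry type below is an illustrative stand-in, and the unregister call on the normal path is assumed rather than shown in this hunk.

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

class CloseRegistrySketch {

    private final Set<Closeable> closables = ConcurrentHashMap.newKeySet();

    void registerClosable(Closeable c) {
        closables.add(c);
    }

    void unregisterClosable(Closeable c) {
        closables.remove(c);
    }

    // called from the canceling thread; closing the stream unblocks a pending read()
    void closeAll() {
        for (Closeable c : closables) {
            try {
                c.close();
            } catch (IOException ignored) {
                // best effort on cancellation
            }
        }
    }

    // restore side: register *before* the first read, clean up in finally
    static void restore(CloseRegistrySketch registry, InputStream in) throws IOException {
        registry.registerClosable(in);
        try {
            // ... read and deserialize key groups; may block in in.read() ...
        } finally {
            registry.unregisterClosable(in);
            in.close();
        }
    }
}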
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java index 1772dbee070ee..33f03ad3ae0fd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java @@ -30,7 +30,7 @@ import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; import java.io.IOException; -import java.util.List; +import java.util.Collection; /** * A {@link AbstractStateBackend} that stores all its data and checkpoints in memory and has no @@ -100,7 +100,7 @@ public AbstractKeyedStateBackend restoreKeyedStateBackend( TypeSerializer keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, - List restoredState, + Collection restoredState, TaskKvStateRegistry kvStateRegistry) throws Exception { return new HeapKeyedStateBackend<>( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java index 6f1bf7b4b96ac..38defccc97084 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java @@ -20,11 +20,11 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.util.Preconditions; /** @@ -43,7 +43,7 @@ public void acknowledgeCheckpoint( JobID jobID, ExecutionAttemptID executionAttemptID, CheckpointMetaData checkpointMetaData, - CheckpointStateHandles checkpointStateHandles) { + SubtaskState checkpointStateHandles) { AcknowledgeCheckpoint message = new AcknowledgeCheckpoint( jobID, executionAttemptID, checkpointMetaData, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java index 4fa20e685b612..7dbb76c52ad75 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java @@ -20,8 +20,8 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.state.CheckpointStateHandles; /** * Responder for checkpoint acknowledge and decline messages in the {@link Task}. 
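With this change the acknowledge path carries a SubtaskState end to end, from the task's environment through the CheckpointResponder to the AcknowledgeCheckpoint message. A minimal test-double sketch against the acknowledgeCheckpoint signature shown above; only that method is mirrored, and the class itself is illustrative rather than an implementation of the full interface.

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.SubtaskState;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;

class RecordingCheckpointResponderSketch {

    volatile SubtaskState lastAcknowledgedState; // inspected by a test

    public void acknowledgeCheckpoint(
            JobID jobID,
            ExecutionAttemptID executionAttemptID,
            CheckpointMetaData checkpointMetaData,
            SubtaskState subtaskState) {
        // a real CheckpointResponder forwards this as an AcknowledgeCheckpoint
        // message to the JobManager; here we only record the payload
        lastAcknowledgedState = subtaskState;
    }
}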
@@ -35,7 +35,7 @@ public interface CheckpointResponder { * Job ID of the running job * @param executionAttemptID * Execution attempt ID of the running task - * @param checkpointStateHandles + * @param subtaskState * State handles for the checkpoint * @param checkpointMetaData * Meta data for this checkpoint @@ -45,7 +45,7 @@ void acknowledgeCheckpoint( JobID jobID, ExecutionAttemptID executionAttemptID, CheckpointMetaData checkpointMetaData, - CheckpointStateHandles checkpointStateHandles); + SubtaskState subtaskState); /** * Declines the given checkpoint. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index f6720e7113bff..fa69a60a928e8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -36,7 +37,6 @@ import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.CheckpointStateHandles; import java.util.Map; import java.util.concurrent.Future; @@ -245,7 +245,7 @@ public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData) { @Override public void acknowledgeCheckpoint( CheckpointMetaData checkpointMetaData, - CheckpointStateHandles checkpointStateHandles) { + SubtaskState checkpointStateHandles) { checkpointResponder.acknowledgeCheckpoint( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 02a41b5484d08..bd522bd929596 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -25,16 +25,11 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.checkpoint.CheckpointMetaData; -import org.apache.flink.runtime.concurrent.BiFunction; -import org.apache.flink.runtime.io.network.PartitionState; -import org.apache.flink.runtime.io.network.netty.PartitionStateChecker; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.concurrent.BiFunction; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import 
org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; @@ -46,23 +41,24 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.PartitionState; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.netty.PartitionStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartition; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.jobgraph.tasks.StoppableTask; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; - +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; import org.slf4j.Logger; @@ -70,7 +66,6 @@ import java.io.IOException; import java.net.URL; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -233,18 +228,10 @@ public class Task implements Runnable, TaskActions { private volatile ExecutorService asyncCallDispatcher; /** - * The handle to the chained operator state that the task was initialized with. Will be set + * The handles to the states that the task was initialized with. Will be set * to null after the initialization, to be memory friendly. */ - private volatile ChainedStateHandle chainedOperatorState; - - /** - * The handle to the key group state that the task was initialized with. Will be set - * to null after the initialization, to be memory friendly. - */ - private volatile List keyGroupStates; - - private volatile List> partitionableOperatorState; + private volatile TaskStateHandles taskStateHandles; /** Initialized from the Flink configuration. 
May also be set at the ExecutionConfig */ private long taskCancellationInterval; @@ -280,10 +267,8 @@ public Task( this.requiredJarFiles = checkNotNull(tdd.getRequiredJarFiles()); this.requiredClasspaths = checkNotNull(tdd.getRequiredClasspaths()); this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName()); - this.chainedOperatorState = tdd.getOperatorState(); this.serializedExecutionConfig = checkNotNull(tdd.getSerializedExecutionConfig()); - this.keyGroupStates = tdd.getKeyGroupState(); - this.partitionableOperatorState = tdd.getPartitionableOperatorState(); + this.taskStateHandles = tdd.getTaskStateHandles(); this.taskCancellationInterval = jobConfiguration.getLong( ConfigConstants.TASK_CANCELLATION_INTERVAL_MILLIS, @@ -570,20 +555,19 @@ else if (current == ExecutionState.CANCELING) { // the state into the task. the state is non-empty if this is an execution // of a task that failed but had backuped state from a checkpoint - if (chainedOperatorState != null || keyGroupStates != null || partitionableOperatorState != null) { + if (null != taskStateHandles) { if (invokable instanceof StatefulTask) { StatefulTask op = (StatefulTask) invokable; - op.setInitialState(chainedOperatorState, keyGroupStates, partitionableOperatorState); + op.setInitialState(taskStateHandles); } else { throw new IllegalStateException("Found operator state for a non-stateful task invokable"); } + // be memory and GC friendly - since the code stays in invoke() for a potentially long time, + // we clear the reference to the state handle + //noinspection UnusedAssignment + taskStateHandles = null; } - // be memory and GC friendly - since the code stays in invoke() for a potentially long time, - // we clear the reference to the state handle - //noinspection UnusedAssignment - this.chainedOperatorState = null; - this.keyGroupStates = null; // ---------------------------------------------------------------- // actual task core work diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java index 27d958aa2602d..ce52eac6487a7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.util; +import java.util.Arrays; import java.util.NoSuchElementException; /** @@ -69,6 +70,10 @@ private void grow(final int length) { } } + public int[] toArray() { + return Arrays.copyOf(array, size); + } + public static final IntArrayList EMPTY = new IntArrayList(0) { @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java index e6532096ef2d4..f2d9556622f0b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.util; +import java.util.Arrays; + /** * Minimal implementation of an array-backed list of longs */ @@ -61,6 +63,10 @@ public void clear() { public boolean isEmpty() { return (size==0); } + + public long[] toArray() { + return Arrays.copyOf(array, size); + } private void grow(int length) { if(length > array.length) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java new file mode 100644 index 0000000000000..ba7bc79eee58a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.util; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Decorator for input streams that ignores calls to {@link InputStream#close()} + * but forwards all other calls to the delegate. + */ +public class NonClosingStreamDecorator extends InputStream { + + private final InputStream delegate; + + public NonClosingStreamDecorator(InputStream delegate) { + this.delegate = delegate; + } + + @Override + public int read() throws IOException { + return delegate.read(); + } + + @Override + public int read(byte[] b) throws IOException { + return delegate.read(b); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return delegate.read(b, off, len); + } + + @Override + public long skip(long n) throws IOException { + return delegate.skip(n); + } + + @Override + public int available() throws IOException { + return delegate.available(); + } + + @Override + public void close() throws IOException { + // ignore close calls so the underlying stream stays usable + } + + @Override + public void mark(int readlimit) { + delegate.mark(readlimit); + } + + @Override + public void reset() throws IOException { + delegate.reset(); + } + + @Override + public boolean markSupported() { + return delegate.markSupported(); + } +} \ No newline at end of file diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 2a20c6c7c09fe..d60f07dd2a4b5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -38,13 +38,13 @@ import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import
org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testutils.CommonTestUtils; @@ -1847,15 +1847,15 @@ public void testRestoreLatestCheckpointedState() throws Exception { long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - List keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1); - List keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2); + List keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8); - List partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); + KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState); + SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -1867,9 +1867,9 @@ public void testRestoreLatestCheckpointedState() throws Exception { for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID2, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8); - List partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState); + ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); + KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); + SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -1952,13 +1952,13 @@ public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - List keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1); - List keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2); + List keyGroupPartitions1 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - List keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); + KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); + SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -1971,8 +1971,8 @@ public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); - List keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); + KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); + SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -2067,17 +2067,17 @@ public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Ex long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - List keyGroupPartitions1 = - CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1); - List keyGroupPartitions2 = - CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2); + List keyGroupPartitions1 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - List keyGroupState = generateKeyGroupState( - jobVertexID1, keyGroupPartitions1.get(index)); + KeyGroupsStateHandle keyGroupState = generateKeyGroupState( + jobVertexID1, keyGroupPartitions1.get(index), false); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState); + SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -2091,10 +2091,10 @@ public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Ex for (int index = 0; index < jobVertex2.getParallelism(); index++) { ChainedStateHandle state = generateStateForVertex(jobVertexID2, index); - List keyGroupState = generateKeyGroupState( - jobVertexID2, 
keyGroupPartitions2.get(index)); + KeyGroupsStateHandle keyGroupState = generateKeyGroupState( + jobVertexID2, keyGroupPartitions2.get(index), false); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState); + SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -2132,24 +2132,36 @@ public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Ex "non-partitioned state changed."); } + @Test + public void testRestoreLatestCheckpointedStateScaleIn() throws Exception { + testRestoreLatestCheckpointedStateWithChangingParallelism(false); + } + + @Test + public void testRestoreLatestCheckpointedStateScaleOut() throws Exception { + testRestoreLatestCheckpointedStateWithChangingParallelism(true); + } + /** * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned * state. * * @throws Exception */ - @Test - public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception { + private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception { final JobID jid = new JobID(); final long timestamp = System.currentTimeMillis(); final JobVertexID jobVertexID1 = new JobVertexID(); final JobVertexID jobVertexID2 = new JobVertexID(); int parallelism1 = 3; - int parallelism2 = 2; + int parallelism2 = scaleOut ? 2 : 13; + int maxParallelism1 = 42; int maxParallelism2 = 13; + int newParallelism2 = scaleOut ? 13 : 2; + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, @@ -2190,18 +2202,20 @@ public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws E long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); - List keyGroupPartitions1 = - CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1); - List keyGroupPartitions2 = - CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2); + List keyGroupPartitions1 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); + //vertex 1 for (int index = 0; index < jobVertex1.getParallelism(); index++) { ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8); - List keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false); + KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false); + KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState); + SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid,
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -2211,13 +2225,19 @@ public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws E coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); } - - final List> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism()); + //vertex 2 + final List> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism()); + final List> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism()); for (int index = 0; index < jobVertex2.getParallelism(); index++) { - List keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); - ChainedStateHandle partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8); - originalPartitionableStates.add(partitionableState); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState); + KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false); + KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true); + ChainedStateHandle opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false); + ChainedStateHandle opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true); + expectedOpStatesBackend.add(opStateBackend); + expectedOpStatesRaw.add(opStateRaw); + SubtaskState checkpointStateHandles = + new SubtaskState(new ChainedStateHandle<>( + Collections.singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw, 0); AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( jid, jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), @@ -2233,16 +2253,15 @@ public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws E Map tasks = new HashMap<>(); - int newParallelism2 = 13; - - List newKeyGroupPartitions2 = - CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, newParallelism2); + List newKeyGroupPartitions2 = + StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2); final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( jobVertexID1, parallelism1, maxParallelism1); + // rescale vertex 2 final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( jobVertexID2, newParallelism2, @@ -2254,19 +2273,28 @@ public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws E // verify the restored state verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); - List>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism()); + List>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism()); + List>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism()); for (int i = 0; i < newJobVertex2.getParallelism(); i++) { - List originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i)); + KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false); + KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true); + + TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); - ChainedStateHandle operatorState = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); - List> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle(); - List keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); + ChainedStateHandle operatorState = taskStateHandles.getLegacyOperatorState(); + List> opStateBackend = taskStateHandles.getManagedOperatorState(); + List> opStateRaw = taskStateHandles.getRawOperatorState(); + Collection keyGroupStateBackend = taskStateHandles.getManagedKeyedState(); + Collection keyGroupStateRaw = taskStateHandles.getRawKeyedState(); - actualPartitionableStates.add(partitionableState); + actualOpStatesBackend.add(opStateBackend); + actualOpStatesRaw.add(opStateRaw); assertNull(operatorState); - compareKeyPartitionedState(originalKeyGroupState, keyGroupState); + compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyGroupStateBackend); + compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); } - comparePartitionableState(originalPartitionableStates, actualPartitionableStates); + comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend); + comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw); } /** @@ -2320,15 +2348,41 @@ public void testExternalizedCheckpoints() throws Exception { // Utilities // ------------------------------------------------------------------------ - public static List generateKeyGroupState( + static void sendAckMessageToCoordinator( + CheckpointCoordinator coord, + long checkpointId, JobID jid, + ExecutionJobVertex jobVertex, + JobVertexID jobVertexID, + List keyGroupPartitions) throws Exception { + + for (int index = 0; index < jobVertex.getParallelism(); index++) { + ChainedStateHandle state = generateStateForVertex(jobVertexID, index); + KeyGroupsStateHandle keyGroupState = generateKeyGroupState( + jobVertexID, + keyGroupPartitions.get(index), false); + + SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0); + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + new CheckpointMetaData(checkpointId, 0L), + checkpointStateHandles); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + } + + public static KeyGroupsStateHandle generateKeyGroupState( JobVertexID jobVertexID, - KeyGroupRange keyGroupPartition) throws IOException { + KeyGroupRange keyGroupPartition, boolean rawState) throws IOException { List testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups()); // generate state for one keygroup for (int keyGroupIndex : keyGroupPartition) { - Random random = new Random(jobVertexID.hashCode() + keyGroupIndex); + int vertexHash = jobVertexID.hashCode(); + int seed = rawState ? 
(vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex); + Random random = new Random(seed); int simulatedStateValue = random.nextInt(); testStatesLists.add(simulatedStateValue); } @@ -2336,7 +2390,7 @@ public static List generateKeyGroupState( return generateKeyGroupState(keyGroupPartition, testStatesLists); } - public static List generateKeyGroupState( + public static KeyGroupsStateHandle generateKeyGroupState( KeyGroupRange keyGroupRange, List states) throws IOException { @@ -2353,9 +2407,7 @@ public static List generateKeyGroupState( KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle( keyGroupRangeOffsets, allSerializedStatesHandle); - List keyGroupsStateHandleList = new ArrayList<>(); - keyGroupsStateHandleList.add(keyGroupsStateHandle); - return keyGroupsStateHandleList; + return keyGroupsStateHandle; } public static Tuple2> serializeTogetherAndTrackOffsets( @@ -2412,14 +2464,19 @@ public static ChainedStateHandle generateChainedPartitionab JobVertexID jobVertexID, int index, int namedStates, - int partitionsPerState) throws IOException { + int partitionsPerState, + boolean rawState) throws IOException { Map> statesListsMap = new HashMap<>(namedStates); for (int i = 0; i < namedStates; ++i) { List testStatesLists = new ArrayList<>(partitionsPerState); // generate state - Random random = new Random(jobVertexID.hashCode() * index + i * namedStates); + int seed = jobVertexID.hashCode() * index + i * namedStates; + if (rawState) { + seed = (seed + 1) * 31; + } + Random random = new Random(seed); for (int j = 0; j < partitionsPerState; ++j) { int simulatedStateValue = random.nextInt(); testStatesLists.add(simulatedStateValue); @@ -2454,7 +2511,7 @@ public static ChainedStateHandle generateChainedPartitionab serializationWithOffsets.f0); OperatorStateHandle operatorStateHandle = - new OperatorStateHandle(streamStateHandle, offsetsMap); + new OperatorStateHandle(offsetsMap, streamStateHandle); return ChainedStateHandle.wrapSingleHandle(operatorStateHandle); } @@ -2528,37 +2585,35 @@ public static void verifiyStateRestore( for (int i = 0; i < executionJobVertex.getParallelism(); i++) { + TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles(); + ChainedStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); - ChainedStateHandle actualNonPartitionedState = executionJobVertex. - getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); + ChainedStateHandle actualNonPartitionedState = taskStateHandles.getLegacyOperatorState(); assertTrue(CommonTestUtils.isSteamContentEqual( expectNonPartitionedState.get(0).openInputStream(), actualNonPartitionedState.get(0).openInputStream())); - ChainedStateHandle expectedPartitionableState = - generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8); + ChainedStateHandle expectedOpStateBackend = + generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false); - List> actualPartitionableState = executionJobVertex. 
- getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle(); + List> actualPartitionableState = taskStateHandles.getManagedOperatorState(); assertTrue(CommonTestUtils.isSteamContentEqual( - expectedPartitionableState.get(0).openInputStream(), + expectedOpStateBackend.get(0).openInputStream(), actualPartitionableState.get(0).iterator().next().openInputStream())); - List expectPartitionedKeyGroupState = generateKeyGroupState( - jobVertexID, - keyGroupPartitions.get(i)); - List actualPartitionedKeyGroupState = executionJobVertex. - getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); - compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState); + KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState( + jobVertexID, keyGroupPartitions.get(i), false); + Collection actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState(); + compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState); } } - public static void compareKeyPartitionedState( - List expectPartitionedKeyGroupState, - List actualPartitionedKeyGroupState) throws Exception { + public static void compareKeyedState( + Collection expectPartitionedKeyGroupState, + Collection actualPartitionedKeyGroupState) throws Exception { - KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.get(0); + KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next(); int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups(); int actualTotalKeyGroups = 0; for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) { @@ -2576,13 +2631,10 @@ public static void compareKeyPartitionedState( for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) { if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) { long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId); - try (FSDataInputStream actualInputStream = - oneActualKeyGroupStateHandle.openInputStream()) { + try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) { actualInputStream.seek(actualOffset); - int actualGroupState = InstantiationUtil. deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader()); - assertEquals(expectedKeyGroupState, actualGroupState); } } @@ -2599,16 +2651,7 @@ public static void comparePartitionableState( for (ChainedStateHandle chainedStateHandle : expected) { for (int i = 0; i < chainedStateHandle.getLength(); ++i) { OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i); - try (FSDataInputStream in = operatorStateHandle.openInputStream()) { - for (Map.Entry entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { - for (long offset : entry.getValue()) { - in.seek(offset); - Integer state = InstantiationUtil. 
- deserializeObject(in, Thread.currentThread().getContextClassLoader()); - expectedResult.add(i + " : " + entry.getKey() + " : " + state); - } - } - } + collectResult(i, operatorStateHandle, expectedResult); } } Collections.sort(expectedResult); @@ -2618,25 +2661,32 @@ public static void comparePartitionableState( if (collectionList != null) { for (int i = 0; i < collectionList.size(); ++i) { Collection stateHandles = collectionList.get(i); + Assert.assertNotNull(stateHandles); for (OperatorStateHandle operatorStateHandle : stateHandles) { - try (FSDataInputStream in = operatorStateHandle.openInputStream()) { - for (Map.Entry entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { - for (long offset : entry.getValue()) { - in.seek(offset); - Integer state = InstantiationUtil. - deserializeObject(in, Thread.currentThread().getContextClassLoader()); - actualResult.add(i + " : " + entry.getKey() + " : " + state); - } - } - } + collectResult(i, operatorStateHandle, actualResult); } } } } + Collections.sort(actualResult); Assert.assertEquals(expectedResult, actualResult); } + private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List resultCollector) throws Exception { + try (FSDataInputStream in = operatorStateHandle.openInputStream()) { + for (Map.Entry entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { + for (long offset : entry.getValue()) { + in.seek(offset); + Integer state = InstantiationUtil. + deserializeObject(in, Thread.currentThread().getContextClassLoader()); + resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state); + } + } + } + } + + @Test public void testCreateKeyGroupPartitions() { testCreateKeyGroupPartitions(1, 1); @@ -2697,7 +2747,7 @@ public void testStopPeriodicScheduler() throws Exception { } private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) { - List ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism); + List ranges = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism); for (int i = 0; i < maxParallelism; ++i) { KeyGroupRange range = ranges.get(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i)); if (!range.contains(i)) { @@ -2743,7 +2793,7 @@ private void doTestPartitionableStateRepartitioning( } previousParallelOpInstanceStates.add( - new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets)); + new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1))); } Map>> expected = new HashMap<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 950526c6652b9..359262fb425e5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -29,17 +29,18 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import 
org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.util.SerializableObject; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.junit.Test; import org.mockito.Mockito; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -65,7 +66,7 @@ public void testSetState() { final ChainedStateHandle serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject()); KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0); List testStates = Collections.singletonList(new SerializableObject()); - final List serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates); + final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates); final JobID jid = new JobID(); final JobVertexID statefulId = new JobVertexID(); @@ -115,7 +116,7 @@ public void testSetState() { PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates); + SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointMetaData, checkpointStateHandles)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointMetaData, checkpointStateHandles)); @@ -131,26 +132,33 @@ public void testSetState() { // verify that each stateful vertex got the state - BaseMatcher matcher = new BaseMatcher() { + final TaskStateHandles taskStateHandles = new TaskStateHandles( + serializedState, + Collections.>singletonList(null), + Collections.>singletonList(null), + Collections.singletonList(serializedKeyGroupStates), + null); + + BaseMatcher matcher = new BaseMatcher() { @Override public boolean matches(Object o) { - if (o instanceof CheckpointStateHandles) { - return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState); + if (o instanceof TaskStateHandles) { + return o.equals(taskStateHandles); } return false; } @Override public void describeTo(Description description) { - description.appendValue(serializedState); + description.appendValue(taskStateHandles); } }; - verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.>>any()); - verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.>>any()); - verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.>>any()); - verify(statelessExec1, times(0)).setInitialState(Mockito.any(), Mockito.>>any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.any(), Mockito.>>any()); + verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher)); + verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher)); + verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher)); + verify(statelessExec1, times(0)).setInitialState(Mockito.any()); + verify(statelessExec2, times(0)).setInitialState(Mockito.any()); } catch (Exception e) { e.printStackTrace(); @@ -164,7 +172,7 @@ public void 
testStateOnlyPartiallyAvailable() { final ChainedStateHandle serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject()); KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0); List testStates = Collections.singletonList(new SerializableObject()); - final List serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates); + final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates); final JobID jid = new JobID(); final JobVertexID statefulId = new JobVertexID(); @@ -215,7 +223,8 @@ public void testStateOnlyPartiallyAvailable() { final long checkpointId = pending.getCheckpointId(); // the difference to the test "testSetState" is that one stateful subtask does not report state - CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates); + SubtaskState checkpointStateHandles = + new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L); CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index baa0e08a2a54f..6b0d3f8852a1e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -206,7 +206,7 @@ protected TestCompletedCheckpoint createCheckpoint(int id, int numberOfStates, C ChainedStateHandle stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle( new CheckpointMessagesTest.MyHandle()); - taskState.putState(i, new SubtaskState(stateHandle, 0)); + taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null, 0L)); } return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java index bad836bf7499f..508a69dbd1392 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java @@ -19,11 +19,13 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.commons.io.output.ByteArrayOutputStream; +import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.junit.Test; import java.io.ByteArrayInputStream; +import java.util.Random; import static org.junit.Assert.assertEquals; @@ -34,19 +36,23 @@ public class SavepointV1SerializerTest { */ @Test public void testSerializeDeserializeV1() throws Exception { - SavepointV1 expected = new SavepointV1(123123, SavepointV1Test.createTaskStates(8, 32)); + Random r = new Random(42); + for (int i = 0; i < 100; ++i) { + SavepointV1 expected = + new SavepointV1(i+ 123123, SavepointV1Test.createTaskStates(1 + r.nextInt(64), 1 + r.nextInt(64))); - SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE; + SavepointV1Serializer serializer = 
SavepointV1Serializer.INSTANCE; - // Serialize - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); - byte[] bytes = baos.toByteArray(); + // Serialize + ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); + serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); + byte[] bytes = baos.toByteArray(); - // Deserialize - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais)); + // Deserialize + ByteArrayInputStream bais = new ByteArrayInputStream(bytes); + Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais)); - assertEquals(expected, actual); + assertEquals(expected, actual); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java index e38e5fba11ea1..1ae74ffdf4546 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java @@ -32,10 +32,10 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.concurrent.ThreadLocalRandom; import static org.junit.Assert.assertEquals; @@ -66,35 +66,83 @@ public void testSavepointV1() throws Exception { assertTrue(savepoint.getTaskStates().isEmpty()); } - static Collection createTaskStates(int numTaskStates, int numSubtaskStates) throws IOException { + static Collection createTaskStates(int numTaskStates, int numSubtasksPerTask) throws IOException { + + Random random = new Random(numTaskStates * 31 + numSubtasksPerTask); + List taskStates = new ArrayList<>(numTaskStates); - for (int i = 0; i < numTaskStates; i++) { - TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1); - for (int j = 0; j < numSubtaskStates; j++) { - StreamStateHandle stateHandle = new TestByteStreamStateHandleDeepCompare("a", "Hello".getBytes()); - taskState.putState(i, new SubtaskState( - new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0)); - - stateHandle = new TestByteStreamStateHandleDeepCompare("b", "Beautiful".getBytes()); - Map offsetsMap = new HashMap<>(); - offsetsMap.put("A", new long[]{0, 10, 20}); - offsetsMap.put("B", new long[]{30, 40, 50}); - - OperatorStateHandle operatorStateHandle = - new OperatorStateHandle(stateHandle, offsetsMap); - - taskState.putPartitionableState( - i, - new ChainedStateHandle( - Collections.singletonList(operatorStateHandle))); - } + for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) { + + int chainLength = 1 + random.nextInt(8); + + TaskState taskState = new TaskState(new JobVertexID(), numSubtasksPerTask, 128, chainLength); + + int noNonPartitionableStateAtIndex = random.nextInt(chainLength); + int noOperatorStateBackendAtIndex = random.nextInt(chainLength); + int noOperatorStateStreamAtIndex = random.nextInt(chainLength); + + boolean hasKeyedBackend = random.nextInt(4) != 0; + boolean hasKeyedStream = random.nextInt(4) != 0; + + for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { - taskState.putKeyedState( - 0, - new KeyGroupsStateHandle( + List 
nonPartitionableStates = new ArrayList<>(chainLength); + List operatorStatesBackend = new ArrayList<>(chainLength); + List operatorStatesStream = new ArrayList<>(chainLength); + + for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) { + + StreamStateHandle nonPartitionableState = + new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes()); + StreamStateHandle operatorStateBackend = + new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes()); + StreamStateHandle operatorStateStream = + new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes()); + Map offsetsMap = new HashMap<>(); + offsetsMap.put("A", new long[]{0, 10, 20}); + offsetsMap.put("B", new long[]{30, 40, 50}); + + if (chainIdx != noNonPartitionableStateAtIndex) { + nonPartitionableStates.add(nonPartitionableState); + } + + if (chainIdx != noOperatorStateBackendAtIndex) { + OperatorStateHandle operatorStateHandleBackend = + new OperatorStateHandle(offsetsMap, operatorStateBackend); + operatorStatesBackend.add(operatorStateHandleBackend); + } + + if (chainIdx != noOperatorStateStreamAtIndex) { + OperatorStateHandle operatorStateHandleStream = + new OperatorStateHandle(offsetsMap, operatorStateStream); + operatorStatesStream.add(operatorStateHandleStream); + } + } + + KeyGroupsStateHandle keyedStateBackend = null; + KeyGroupsStateHandle keyedStateStream = null; + + if (hasKeyedBackend) { + keyedStateBackend = new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{42}), - new TestByteStreamStateHandleDeepCompare("c", "World".getBytes()))); + new TestByteStreamStateHandleDeepCompare("c", "Hello".getBytes())); + } + + if (hasKeyedStream) { + keyedStateStream = new KeyGroupsStateHandle( + new KeyGroupRangeOffsets(1, 1, new long[]{23}), + new TestByteStreamStateHandleDeepCompare("d", "World".getBytes())); + } + + taskState.putState(subtaskIdx, new SubtaskState( + new ChainedStateHandle<>(nonPartitionableStates), + new ChainedStateHandle<>(operatorStatesBackend), + new ChainedStateHandle<>(operatorStatesStream), + keyedStateBackend, + keyedStateStream, + subtaskIdx * 10L)); + } taskStates.add(taskState); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java index 2dac87f2b9451..50a59a5134a18 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java @@ -335,8 +335,7 @@ private static CompletedCheckpoint[] generateRandomCheckpoints( StreamStateHandle proxy = new StateHandleProxy(new Path(), proxySize); SubtaskState subtaskState = new SubtaskState( - new ChainedStateHandle<>(Collections.singletonList(proxy)), - duration); + new ChainedStateHandle<>(Collections.singletonList(proxy)), null, null, null, null, duration); taskState.putState(subtaskIndex, subtaskState); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index 5ec699167798d..b195858edf1da 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy; @@ -57,10 +58,8 @@ import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingJobManager; @@ -85,7 +84,6 @@ import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -446,12 +444,10 @@ public static class BlockingStatefulInvokable extends BlockingInvokable implemen @Override public void setInitialState( - ChainedStateHandle chainedState, - List keyGroupsState, - List> partitionableOperatorState) throws Exception { + TaskStateHandles taskStateHandles) throws Exception { int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { - try (FSDataInputStream in = chainedState.get(0).openInputStream()) { + try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) { recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); } } @@ -466,9 +462,8 @@ public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData) { ChainedStateHandle chainedStateHandle = new ChainedStateHandle(Collections.singletonList(byteStreamStateHandle)); - - CheckpointStateHandles checkpointStateHandles = - new CheckpointStateHandles(chainedStateHandle, null, Collections.emptyList()); + SubtaskState checkpointStateHandles = + new SubtaskState(chainedStateHandle, null, null, null, null, 0L); getEnvironment().acknowledgeCheckpoint( new CheckpointMetaData(checkpointMetaData.getCheckpointId(), -1, 0L, 0L, 0L, 0L), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java index 305625e83504e..3521630ee46e3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java @@ -23,12 +23,12 @@ import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; 
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.StreamStateHandle; import org.junit.Test; @@ -67,11 +67,14 @@ public void testConfirmTaskCheckpointed() { KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42); - CheckpointStateHandles checkpointStateHandles = - new CheckpointStateHandles( + SubtaskState checkpointStateHandles = + new SubtaskState( CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()), - CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8), - CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle()))); + CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false), + null, + CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())), + null, + 0L); AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint( new JobID(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 04ba4e520ba32..f2616b569143d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -37,7 +38,6 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import java.util.Collections; @@ -155,8 +155,7 @@ public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData) { } @Override - public void acknowledgeCheckpoint( - CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) { + public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) { } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index eb55c4de48679..08b84cb42277f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.accumulators.AccumulatorRegistry; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import 
org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; @@ -46,7 +47,6 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.query.TaskKvStateRegistry; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.types.Record; import org.apache.flink.util.MutableObjectIterator; @@ -316,8 +316,7 @@ public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData) { } @Override - public void acknowledgeCheckpoint( - CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) { + public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) { } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java index 95564cc95da16..fb24712e0013f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java @@ -45,7 +45,7 @@ public void testKeyGroupIntersection() { keyGroupRangeOffsets.getKeyGroupRange())); intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(11, 13)); - Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection.getKeyGroupRange()); + Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection.getKeyGroupRange()); Assert.assertFalse(intersection.iterator().hasNext()); intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(5, 13)); @@ -129,7 +129,7 @@ private void testKeyGroupRangeOffsetsBasicsInternal(int startKeyGroup, int endKe Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(startKeyGroup - 1)); Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(endKeyGroup + 1)); } else { - Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange); + Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java index ab0c32715ba8c..94350ad4f5133 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java @@ -37,7 +37,7 @@ public void testKeyGroupIntersection() { keyGroupRange1 = KeyGroupRange.of(0,5); keyGroupRange2 = KeyGroupRange.of(6,10); intersection =keyGroupRange1.getIntersection(keyGroupRange2); - Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection); + Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection); Assert.assertEquals(intersection, keyGroupRange2.getIntersection(keyGroupRange1)); keyGroupRange1 = KeyGroupRange.of(0, 10); @@ -93,7 +93,7 @@ private void testKeyGroupRangeBasicsInternal(int startKeyGroup, int endKeyGroup) Assert.assertFalse(keyGroupRange.contains(startKeyGroup - 1)); Assert.assertFalse(keyGroupRange.contains(endKeyGroup + 1)); } else { - Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange); + 
Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java new file mode 100644 index 0000000000000..0c4ed742406bf --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.runtime.state; + +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class KeyedStateCheckpointOutputStreamTest { + + private static final int STREAM_CAPACITY = 128; + + private static KeyedStateCheckpointOutputStream createStream(KeyGroupRange keyGroupRange) { + CheckpointStreamFactory.CheckpointStateOutputStream checkStream = + new TestMemoryCheckpointOutputStream(STREAM_CAPACITY); + return new KeyedStateCheckpointOutputStream(checkStream, keyGroupRange); + } + + private KeyGroupsStateHandle writeAllTestKeyGroups( + KeyedStateCheckpointOutputStream stream, KeyGroupRange keyRange) throws Exception { + + DataOutputView dov = new DataOutputViewStreamWrapper(stream); + for (int kg : keyRange) { + stream.startNewKeyGroup(kg); + dov.writeInt(kg); + } + + return stream.closeAndGetHandle(); + } + + @Test + public void testCloseNotPropagated() throws Exception { + KeyedStateCheckpointOutputStream stream = createStream(new KeyGroupRange(0, 0)); + TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate(); + stream.close(); + Assert.assertFalse(innerStream.isClosed()); + } + + @Test + public void testEmptyKeyedStream() throws Exception { + final KeyGroupRange keyRange = new KeyGroupRange(0, 2); + KeyedStateCheckpointOutputStream stream = createStream(keyRange); + TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate(); + KeyGroupsStateHandle emptyHandle = stream.closeAndGetHandle(); + Assert.assertTrue(innerStream.isClosed()); + Assert.assertEquals(null, emptyHandle); + } + + @Test + public void testWriteReadRoundtrip() throws Exception { + final KeyGroupRange keyRange = new KeyGroupRange(0, 2); + KeyedStateCheckpointOutputStream stream = createStream(keyRange); + KeyGroupsStateHandle fullHandle = writeAllTestKeyGroups(stream, keyRange); + Assert.assertNotNull(fullHandle); + + 
verifyRead(fullHandle, keyRange); + } + + @Test + public void testWriteKeyGroupTracking() throws Exception { + final KeyGroupRange keyRange = new KeyGroupRange(0, 2); + KeyedStateCheckpointOutputStream stream = createStream(keyRange); + + try { + stream.startNewKeyGroup(4711); + Assert.fail(); + } catch (IllegalArgumentException expected) { + // good + } + + Assert.assertEquals(-1, stream.getCurrentKeyGroup()); + + DataOutputView dov = new DataOutputViewStreamWrapper(stream); + int previous = -1; + for (int kg : keyRange) { + Assert.assertFalse(stream.isKeyGroupAlreadyStarted(kg)); + Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg)); + stream.startNewKeyGroup(kg); + if(-1 != previous) { + Assert.assertTrue(stream.isKeyGroupAlreadyStarted(previous)); + Assert.assertTrue(stream.isKeyGroupAlreadyFinished(previous)); + } + Assert.assertTrue(stream.isKeyGroupAlreadyStarted(kg)); + Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg)); + dov.writeInt(kg); + previous = kg; + } + + KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle(); + + verifyRead(fullHandle, keyRange); + + for (int kg : keyRange) { + try { + stream.startNewKeyGroup(kg); + Assert.fail(); + } catch (IOException ex) { + // required + } + } + } + + @Test + public void testReadWriteMissingKeyGroups() throws Exception { + final KeyGroupRange keyRange = new KeyGroupRange(0, 2); + KeyedStateCheckpointOutputStream stream = createStream(keyRange); + + DataOutputView dov = new DataOutputViewStreamWrapper(stream); + stream.startNewKeyGroup(1); + dov.writeInt(1); + + KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle(); + + int count = 0; + try (FSDataInputStream in = fullHandle.openInputStream()) { + DataInputView div = new DataInputViewStreamWrapper(in); + for (int kg : fullHandle.keyGroups()) { + long off = fullHandle.getOffsetForKeyGroup(kg); + if (off >= 0) { + in.seek(off); + Assert.assertEquals(1, div.readInt()); + ++count; + } + } + } + + Assert.assertEquals(1, count); + } + + private static void verifyRead(KeyGroupsStateHandle fullHandle, KeyGroupRange keyRange) throws IOException { + int count = 0; + try (FSDataInputStream in = fullHandle.openInputStream()) { + DataInputView div = new DataInputViewStreamWrapper(in); + for (int kg : fullHandle.keyGroups()) { + long off = fullHandle.getOffsetForKeyGroup(kg); + in.seek(off); + Assert.assertEquals(kg, div.readInt()); + ++count; + } + } + + Assert.assertEquals(keyRange.getNumberOfKeyGroups(), count); + } +} \ No newline at end of file diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java new file mode 100644 index 0000000000000..c6ef0f061564a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class OperatorStateOutputCheckpointStreamTest { + + private static final int STREAM_CAPACITY = 128; + + private static OperatorStateCheckpointOutputStream createStream() throws IOException { + CheckpointStreamFactory.CheckpointStateOutputStream checkStream = + new TestMemoryCheckpointOutputStream(STREAM_CAPACITY); + return new OperatorStateCheckpointOutputStream(checkStream); + } + + private OperatorStateHandle writeAllTestKeyGroups( + OperatorStateCheckpointOutputStream stream, int numPartitions) throws Exception { + + DataOutputView dov = new DataOutputViewStreamWrapper(stream); + for (int i = 0; i < numPartitions; ++i) { + Assert.assertEquals(i, stream.getNumberOfPartitions()); + stream.startNewPartition(); + dov.writeInt(i); + } + + return stream.closeAndGetHandle(); + } + + @Test + public void testCloseNotPropagated() throws Exception { + OperatorStateCheckpointOutputStream stream = createStream(); + TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate(); + stream.close(); + Assert.assertFalse(innerStream.isClosed()); + innerStream.close(); + } + + @Test + public void testEmptyOperatorStream() throws Exception { + OperatorStateCheckpointOutputStream stream = createStream(); + TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate(); + OperatorStateHandle emptyHandle = stream.closeAndGetHandle(); + Assert.assertTrue(innerStream.isClosed()); + Assert.assertEquals(0, stream.getNumberOfPartitions()); + Assert.assertEquals(null, emptyHandle); + } + + @Test + public void testWriteReadRoundtrip() throws Exception { + int numPartitions = 3; + OperatorStateCheckpointOutputStream stream = createStream(); + OperatorStateHandle fullHandle = writeAllTestKeyGroups(stream, numPartitions); + Assert.assertNotNull(fullHandle); + + verifyRead(fullHandle, numPartitions); + } + + private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException { + int count = 0; + try (FSDataInputStream in = fullHandle.openInputStream()) { + long[] offsets = fullHandle.getStateNameToPartitionOffsets(). 
+ get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + + Assert.assertNotNull(offsets); + + DataInputView div = new DataInputViewStreamWrapper(in); + for (int i = 0; i < numPartitions; ++i) { + in.seek(offsets[i]); + Assert.assertEquals(i, div.readInt()); + ++count; + } + } + + Assert.assertEquals(numPartitions, count); + } + +} \ No newline at end of file diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index 2f215741167d3..9e835ce96e9ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -38,7 +38,7 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.runtime.checkpoint.CheckpointCoordinator; +import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.query.KvStateID; @@ -707,11 +707,11 @@ public void testKeyGroupSnapshotRestore() throws Exception { KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, streamFactory)); - List firstHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles( + List firstHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles( Collections.singletonList(snapshot), KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 0)); - List secondHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles( + List secondHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles( Collections.singletonList(snapshot), KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java new file mode 100644 index 0000000000000..5accc1944d948 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; + +import java.io.IOException; + +final class TestMemoryCheckpointOutputStream extends MemCheckpointStreamFactory.MemoryCheckpointOutputStream { + + private boolean closed; + + public TestMemoryCheckpointOutputStream(int maxSize) { + super(maxSize); + this.closed = false; + } + + @Override + public void close() { + this.closed = true; + super.close(); + } + + public boolean isClosed() { + return this.closed; + } + + @Override + public StreamStateHandle closeAndGetHandle() throws IOException { + this.closed = true; + return super.closeAndGetHandle(); + } +} \ No newline at end of file diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index e2abe88c0ab86..7dd67ed46c00f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -48,6 +48,7 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.util.SerializedValue; import org.junit.Before; import org.junit.Test; @@ -209,9 +210,7 @@ public void invoke() throws Exception { } @Override - public void setInitialState(ChainedStateHandle chainedState, - List keyGroupsState, - List> partitionableOperatorState) throws Exception { + public void setInitialState(TaskStateHandles taskStateHandles) throws Exception { } diff --git a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java index ac1e3f05ad4db..0c0111cd35aaf 100644 --- a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java +++ b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java @@ -137,7 +137,7 @@ public void testCheckpointWithoutNotify() throws Exception { // snapshot but don't call notify to simulate a notify that never // arrives, the sink should move pending files in restore() in that case - StreamStateHandle snapshot1 = testHarness.snapshot(0, 0); + StreamStateHandle snapshot1 = testHarness.snapshotLegacy(0, 0); testHarness = createTestSink(dataDir, clock); testHarness.setup(); diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java index 05028e6a870c1..76101dc8f468c 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java @@ -19,13 +19,16 @@ import org.apache.commons.collections.map.LinkedMap; import 
org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.ClosureCleaner; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.CheckpointListener; -import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; @@ -37,11 +40,9 @@ import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState; import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema; import org.apache.flink.util.SerializedValue; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -97,7 +98,7 @@ public abstract class FlinkKafkaConsumerBase extends RichParallelSourceFuncti * The assigner is kept in serialized form, to deserialize it into multiple copies */ private SerializedValue> punctuatedWatermarkAssigner; - private transient OperatorStateStore stateStore; + private transient ListState> offsetsStateForCheckpoint; // ------------------------------------------------------------------------ // runtime state (used individually by each parallel subtask) @@ -311,33 +312,33 @@ public void close() throws Exception { // ------------------------------------------------------------------------ @Override - public void initializeState(OperatorStateStore stateStore) throws Exception { - - this.stateStore = stateStore; + public void initializeState(FunctionInitializationContext context) throws Exception { - ListState offsets = - stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); + OperatorStateStore stateStore = context.getManagedOperatorStateStore(); + offsetsStateForCheckpoint = stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); - restoreToOffset = new HashMap<>(); + if (context.isRestored()) { + restoreToOffset = new HashMap<>(); + for (Tuple2 kafkaOffset : offsetsStateForCheckpoint.get()) { + restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1); + } - for (Serializable serializable : offsets.get()) { - @SuppressWarnings("unchecked") - Tuple2 kafkaOffset = (Tuple2) serializable; - restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1); + LOG.info("Setting restore state in the FlinkKafkaConsumer."); + if (LOG.isDebugEnabled()) { + LOG.debug("Using the following offsets: {}", restoreToOffset); + } + } else { + LOG.info("No restore state for FlinkKafkaConsumer."); } - - LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoreToOffset); } @Override - public void prepareSnapshot(long checkpointId, long timestamp) throws Exception { + public void snapshotState(FunctionSnapshotContext context) throws Exception { if (!running) { - LOG.debug("storeOperatorState() called on closed source"); + LOG.debug("snapshotState() called on 
closed source"); } else { - ListState listState = - stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); - listState.clear(); + offsetsStateForCheckpoint.clear(); final AbstractFetcher fetcher = this.kafkaFetcher; if (fetcher == null) { @@ -347,7 +348,7 @@ public void prepareSnapshot(long checkpointId, long timestamp) throws Exception if (restoreToOffset != null) { // the map cannot be asynchronously updated, because only one checkpoint call can happen // on this function at a time: either snapshotState() or notifyCheckpointComplete() - pendingCheckpoints.put(checkpointId, restoreToOffset); + pendingCheckpoints.put(context.getCheckpointId(), restoreToOffset); // truncate the map, to prevent infinite growth while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) { @@ -355,11 +356,13 @@ public void prepareSnapshot(long checkpointId, long timestamp) throws Exception } for (Map.Entry kafkaTopicPartitionLongEntry : restoreToOffset.entrySet()) { - listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue())); + offsetsStateForCheckpoint.add( + Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue())); } } else if (subscribedPartitions != null) { for (KafkaTopicPartition subscribedPartition : subscribedPartitions) { - listState.add(Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET)); + offsetsStateForCheckpoint.add( + Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET)); } } } else { @@ -367,7 +370,7 @@ public void prepareSnapshot(long checkpointId, long timestamp) throws Exception // the map cannot be asynchronously updated, because only one checkpoint call can happen // on this function at a time: either snapshotState() or notifyCheckpointComplete() - pendingCheckpoints.put(checkpointId, currentOffsets); + pendingCheckpoints.put(context.getCheckpointId(), currentOffsets); // truncate the map, to prevent infinite growth while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) { @@ -375,7 +378,8 @@ public void prepareSnapshot(long checkpointId, long timestamp) throws Exception } for (Map.Entry kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) { - listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue())); + offsetsStateForCheckpoint.add( + Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue())); } } } diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java index 26a695e86d50e..bede064f9d3d2 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java @@ -18,10 +18,12 @@ package org.apache.flink.streaming.connectors.kafka; import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.java.ClosureCleaner; import org.apache.flink.configuration.Configuration; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.api.common.state.OperatorStateStore; +import 
org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.util.SerializableObject; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; @@ -330,12 +332,12 @@ private void acknowledgeMessage() { protected abstract void flush(); @Override - public void initializeState(OperatorStateStore stateStore) throws Exception { - this.stateStore = stateStore; + public void initializeState(FunctionInitializationContext context) throws Exception { + this.stateStore = context.getManagedOperatorStateStore(); } @Override - public void prepareSnapshot(long checkpointId, long timestamp) throws Exception { + public void snapshotState(FunctionSnapshotContext ctx) throws Exception { if (flushOnCheckpoint) { // flushing is activated: We need to wait until pendingRecords is 0 flush(); diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java index d2d7fca7df2e1..5e9bacc44c6f4 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java @@ -20,6 +20,8 @@ import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; import org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext; import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema; import org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper; @@ -112,7 +114,7 @@ public void run() { Thread threadB = new Thread(confirmer); threadB.start(); // this should block: - producer.prepareSnapshot(0, 0); + producer.snapshotState(new StateSnapshotContextSynchronousImpl(0, 0)); synchronized (threadA) { threadA.notifyAll(); // just in case, to let the test fail faster } @@ -148,9 +150,9 @@ protected KafkaProducer getKafkaProducer(Properties props) { } @Override - public void prepareSnapshot(long checkpointId, long timestamp) throws Exception { + public void snapshotState(FunctionSnapshotContext ctx) throws Exception { // call the actual snapshot state - super.prepareSnapshot(checkpointId, timestamp); + super.snapshotState(ctx); // notify test that snapshotting has been done snapshottingFinished.set(true); } diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java index 6d2dc7038c176..1c7626c3fbc01 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java @@ -19,19 +19,26 @@ package 
org.apache.flink.streaming.connectors.kafka; import org.apache.commons.collections.map.LinkedMap; +import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; +import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher; import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition; import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema; import org.apache.flink.util.SerializedValue; +import org.junit.Assert; import org.junit.Test; import org.mockito.Matchers; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.io.Serializable; import java.lang.reflect.Field; @@ -47,6 +54,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -100,7 +109,7 @@ public void ignoreCheckpointWhenNotRunning() throws Exception { TestingListState> listState = new TestingListState<>(); when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); - consumer.prepareSnapshot(17L, 17L); + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(1, 1)); assertFalse(listState.get().iterator().hasNext()); consumer.notifyCheckpointComplete(66L); @@ -113,24 +122,30 @@ public void ignoreCheckpointWhenNotRunning() throws Exception { public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception { OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - TestingListState expectedState = new TestingListState<>(); - expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L)); - expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L)); - TestingListState listState = new TestingListState<>(); + listState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L)); + listState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L)); FlinkKafkaConsumerBase consumer = getConsumer(null, new LinkedMap(), true); - when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(expectedState); - consumer.initializeState(operatorStateStore); - when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState); - consumer.prepareSnapshot(17L, 17L); + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + + when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore); + when(initializationContext.isRestored()).thenReturn(true); + + consumer.initializeState(initializationContext); + + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17)); 
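+ // A sketch of the contract this test exercises (names below are illustrative,
+ // not taken from this patch): a CheckpointedFunction registers its list state once
+ // in initializeState() and clears and refills it on every snapshot:
+ //
+ //   public void initializeState(FunctionInitializationContext context) throws Exception {
+ //       offsets = context.getManagedOperatorStateStore()
+ //           .getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+ //       if (context.isRestored()) {
+ //           // rebuild the in-memory view from offsets.get()
+ //       }
+ //   }
+ //
+ //   public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ //       offsets.clear();            // this clear() is what isClearCalled() observes below
+ //       offsets.add(currentState);  // refill with the values to checkpoint
+ //   }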
+ + // ensure that the list was cleared and refilled. while this is an implementation detail, we use it here + // to figure out that snapshotState() actually did something. + Assert.assertTrue(listState.isClearCalled()); Set expected = new HashSet<>(); - for (Serializable serializable : expectedState.get()) { + for (Serializable serializable : listState.get()) { expected.add(serializable); } @@ -155,12 +170,40 @@ public void checkRestoredNullCheckpointWhenFetcherNotReady() throws Exception { TestingListState listState = new TestingListState<>(); when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState); - consumer.initializeState(operatorStateStore); - consumer.prepareSnapshot(17L, 17L); + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + + when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore); + when(initializationContext.isRestored()).thenReturn(false); + + consumer.initializeState(initializationContext); + + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17)); assertFalse(listState.get().iterator().hasNext()); } + @Test + public void checkUseFetcherWhenNoCheckpoint() throws Exception { + + FlinkKafkaConsumerBase consumer = getConsumer(null, new LinkedMap(), true); + List partitionList = new ArrayList<>(1); + partitionList.add(new KafkaTopicPartition("test", 0)); + consumer.setSubscribedPartitions(partitionList); + + OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); + TestingListState listState = new TestingListState<>(); + when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + + when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore); + + // make the context signal that there is no restored state, then validate that + when(initializationContext.isRestored()).thenReturn(false); + consumer.initializeState(initializationContext); + consumer.run(mock(SourceFunction.SourceContext.class)); + } + @Test @SuppressWarnings("unchecked") public void testSnapshotState() throws Exception { @@ -186,22 +229,23 @@ public void testSnapshotState() throws Exception { OperatorStateStore backend = mock(OperatorStateStore.class); - TestingListState init = new TestingListState<>(); - TestingListState listState1 = new TestingListState<>(); - TestingListState listState2 = new TestingListState<>(); - TestingListState listState3 = new TestingListState<>(); + TestingListState listState = new TestingListState<>(); + + when(backend.getSerializableListState(Matchers.any(String.class))).thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); - when(backend.getSerializableListState(Matchers.any(String.class))). 
- thenReturn(init, listState1, listState2, listState3); + when(initializationContext.getManagedOperatorStateStore()).thenReturn(backend); + when(initializationContext.isRestored()).thenReturn(false, true, true, true); - consumer.initializeState(backend); + consumer.initializeState(initializationContext); // checkpoint 1 - consumer.prepareSnapshot(138L, 138L); + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138)); HashMap snapshot1 = new HashMap<>(); - for (Serializable serializable : listState1.get()) { + for (Serializable serializable : listState.get()) { Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } @@ -211,11 +255,11 @@ public void testSnapshotState() throws Exception { assertEquals(state1, pendingCheckpoints.get(138L)); // checkpoint 2 - consumer.prepareSnapshot(140L, 140L); + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(140, 140)); HashMap snapshot2 = new HashMap<>(); - for (Serializable serializable : listState2.get()) { + for (Serializable serializable : listState.get()) { Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } @@ -230,11 +274,11 @@ public void testSnapshotState() throws Exception { assertTrue(pendingCheckpoints.containsKey(140L)); // checkpoint 3 - consumer.prepareSnapshot(141L, 141L); + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(141, 141)); HashMap snapshot3 = new HashMap<>(); - for (Serializable serializable : listState3.get()) { + for (Serializable serializable : listState.get()) { Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } @@ -252,12 +296,12 @@ public void testSnapshotState() throws Exception { assertEquals(0, pendingCheckpoints.size()); OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - TestingListState> listState = new TestingListState<>(); + listState = new TestingListState<>(); when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); // create 500 snapshots for (int i = 100; i < 600; i++) { - consumer.prepareSnapshot(i, i); + consumer.snapshotState(new StateSnapshotContextSynchronousImpl(i, i)); listState.clear(); } assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingCheckpoints.size()); @@ -298,7 +342,7 @@ private static FlinkKafkaConsumerBase getConsumer( // ------------------------------------------------------------------------ - private static final class DummyFlinkKafkaConsumer extends FlinkKafkaConsumerBase { + private static class DummyFlinkKafkaConsumer extends FlinkKafkaConsumerBase { private static final long serialVersionUID = 1L; @SuppressWarnings("unchecked") @@ -308,22 +352,37 @@ public DummyFlinkKafkaConsumer() { @Override protected AbstractFetcher createFetcher(SourceContext sourceContext, List thisSubtaskPartitions, SerializedValue> watermarksPeriodic, SerializedValue> watermarksPunctuated, StreamingRuntimeContext runtimeContext) throws Exception { - return null; + AbstractFetcher fetcher = mock(AbstractFetcher.class); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + Assert.fail("Trying to restore offsets even though there was no restore state."); + return null; + } + }).when(fetcher).restoreOffsets(any(HashMap.class)); + return 
fetcher; } @Override protected List getKafkaPartitions(List topics) { return Collections.emptyList(); } + + @Override + public RuntimeContext getRuntimeContext() { + return mock(StreamingRuntimeContext.class); + } } private static final class TestingListState implements ListState { private final List list = new ArrayList<>(); + private boolean clearCalled = false; @Override public void clear() { list.clear(); + clearCalled = true; } @Override @@ -335,5 +394,13 @@ public Iterable get() throws Exception { public void add(T value) throws Exception { list.add(value); } + + public List getList() { + return list; + } + + public boolean isClearCalled() { + return clearCalled; + } } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java index 4a0fd6066fc44..7af5ceab9213d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java @@ -38,7 +38,7 @@ */ @Deprecated @PublicEvolving -public interface Checkpointed { +public interface Checkpointed extends CheckpointedRestoring { /** * Gets the current state of the function or operator. The state must reflect the result of all @@ -56,14 +56,4 @@ public interface Checkpointed { * and to try again with the next checkpoint attempt. */ T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception; - - /** - * Restores the state of the function or operator to that of a previous checkpoint. - * This method is invoked when a function is executed as part of a recovery run. - * - * Note that restoreState() is called before open(). - * - * @param state The state to be restored. - */ - void restoreState(T state) throws Exception; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java index 777cb91025a89..37d8244d2f187 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java @@ -20,46 +20,48 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; /** * * Similar to {@link Checkpointed}, this interface must be implemented by functions that have potentially * repartitionable state that needs to be checkpointed. Methods from this interface are called upon checkpointing and - * restoring of state. + * initialization of state. * - * On #initializeState the implementing class receives the {@link OperatorStateStore} - * to store it's state. At least before each snapshot, all state persistent state must be stored in the state store. + * On {@link #initializeState(FunctionInitializationContext)} the implementing class receives a + * {@link FunctionInitializationContext} which provides access to the {@link OperatorStateStore} (all) and + * {@link org.apache.flink.api.common.state.KeyedStateStore} (only for keyed operators). These allow registering + * managed operator / keyed user states. 
Furthermore, the context provides information whether or not the operator was + * restored. * - * When the backend is received for initialization, the user registers states with the backend via - * {@link org.apache.flink.api.common.state.StateDescriptor}. Then, all previously stored state is found in the - * received {@link org.apache.flink.api.common.state.State} (currently only - * {@link org.apache.flink.api.common.state.ListState} is supported. * - * In #prepareSnapshot, the implementing class must ensure that all operator state is passed to the operator backend, - * i.e. that the state was stored in the relevant {@link org.apache.flink.api.common.state.State} instances that - * are requested on restore. Notice that users might want to clear and reinsert the complete state first if incremental - * updates of the states are not possible. + * In {@link #snapshotState(FunctionSnapshotContext)} the implementing class must ensure that all operator / keyed state + * is passed to user states that have been registered during initialization, so that it is visible to the system + * backends for checkpointing. + * */ @PublicEvolving public interface CheckpointedFunction { /** + * This method is called when a snapshot for a checkpoint is requested. This acts as a hook to the function to + * ensure that all state is exposed by means previously offered through {@link FunctionInitializationContext} when + * the Function was initialized, or offered now by {@link FunctionSnapshotContext} itself. * - * This method is called when state should be stored for a checkpoint. The state can be registered and written to - * the provided backend. - * - * @param checkpointId Id of the checkpoint to perform - * @param timestamp Timestamp of the checkpoint + * @param context the context for drawing a snapshot of the operator * @throws Exception */ - void prepareSnapshot(long checkpointId, long timestamp) throws Exception; + void snapshotState(FunctionSnapshotContext context) throws Exception; /** - * This method is called when an operator is opened, so that the function can set the state backend to which it - * hands it's state on snapshot. + * This method is called when an operator is initialized, so that the function can set up its state through + * the provided context. Initialization typically includes registering user states through the state stores + * that the context offers. * - * @param stateStore the state store to which this function stores it's state + * @param context the context for initializing the operator * @throws Exception */ - void initializeState(OperatorStateStore stateStore) throws Exception; + void initializeState(FunctionInitializationContext context) throws Exception; + } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java new file mode 100644 index 0000000000000..c0dd361c70864 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.checkpoint; + +import org.apache.flink.annotation.PublicEvolving; + +import java.io.Serializable; + +/** + * This deprecated interface contains the methods for restoring state from the legacy checkpointing mechanism. + * @param <T> type of the restored state. + */ +@Deprecated +@PublicEvolving +public interface CheckpointedRestoring { + /** + * Restores the state of the function or operator to that of a previous checkpoint. + * This method is invoked when a function is executed as part of a recovery run. + * + * Note that restoreState() is called before open(). + * + * @param state The state to be restored. + */ + void restoreState(T state) throws Exception; +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 167dfb04ad2d6..279e828faea36 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -22,38 +22,47 @@ import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.Gauge; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.DefaultKeyedStateStore; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateInitializationContextImpl; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.watermark.Watermark; import 
org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collection; +import java.util.ConcurrentModificationException; import java.util.HashMap; import java.util.Map; -import java.util.ConcurrentModificationException; -import java.util.Collection; -import java.util.concurrent.RunnableFuture; /** * Base class for all stream operators. Operators that contain a user function should extend the class @@ -97,6 +106,7 @@ public abstract class AbstractStreamOperator private transient StreamingRuntimeContext runtimeContext; + // ---------------- key/value state ------------------ /** key selector used to get the key for the state. Non-null only is the operator uses key/value state */ @@ -106,11 +116,12 @@ public abstract class AbstractStreamOperator /** Backend for keyed state. This might be empty if we're not on a keyed stream. */ private transient AbstractKeyedStateBackend keyedStateBackend; - /** Operator state backend */ + /** Keyed state store view on the keyed backend */ + private transient DefaultKeyedStateStore keyedStateStore; + + /** Operator state backend / store */ private transient OperatorStateBackend operatorStateBackend; - private transient Collection lazyRestoreStateHandles; - // --------------- Metrics --------------------------- @@ -151,8 +162,61 @@ public MetricGroup getMetricGroup() { } @Override - public void restoreState(Collection stateHandles) { - this.lazyRestoreStateHandles = stateHandles; + public final void initializeState(OperatorStateHandles stateHandles) throws Exception { + + Collection keyedStateHandlesRaw = null; + Collection operatorStateHandlesRaw = null; + Collection operatorStateHandlesBackend = null; + + boolean restoring = null != stateHandles; + + initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class + + if (restoring) { + + // TODO check that there is EITHER old OR new state in handles! 
+ restoreStreamCheckpointed(stateHandles); + + //pass directly + operatorStateHandlesBackend = stateHandles.getManagedOperatorState(); + operatorStateHandlesRaw = stateHandles.getRawOperatorState(); + + if (null != getKeyedStateBackend()) { + //only use the keyed state if it is meant for us (aka head operator) + keyedStateHandlesRaw = stateHandles.getRawKeyedState(); + } + } + + initOperatorState(operatorStateHandlesBackend); + + StateInitializationContext initializationContext = new StateInitializationContextImpl( + restoring, // information whether we restore or start for the first time + operatorStateBackend, // access to operator state backend + keyedStateStore, // access to keyed state backend + keyedStateHandlesRaw, // access to keyed state stream + operatorStateHandlesRaw, // access to operator state stream + getContainingTask().getCancelables()); // access to register streams for canceling + + initializeState(initializationContext); + } + + @Deprecated + private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception { + StreamStateHandle state = stateHandles.getLegacyOperatorState(); + if (this instanceof StreamCheckpointedOperator && null != state) { + + LOG.debug("Restore state of task {} in chain ({}).", + stateHandles.getOperatorChainIndex(), getContainingTask().getName()); + + FSDataInputStream is = state.openInputStream(); + try { + getContainingTask().getCancelables().registerClosable(is); + ((StreamCheckpointedOperator) this).restoreState(is); + } finally { + getContainingTask().getCancelables().unregisterClosable(is); + is.close(); + } + } } /** @@ -165,8 +229,7 @@ public void restoreState(Collection stateHandles) { */ @Override public void open() throws Exception { - initOperatorState(); - initKeyedState(); + } private void initKeyedState() { @@ -174,7 +237,6 @@ private void initKeyedState() { TypeSerializer keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); // create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer if (null != keySerializer) { - KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(), container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(), @@ -184,7 +246,8 @@ private void initKeyedState() { keySerializer, container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()), subTaskKeyGroupRange); - + + this.keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getExecutionConfig()); } } catch (Exception e) { @@ -192,10 +255,10 @@ private void initKeyedState() { } } - private void initOperatorState() { + private void initOperatorState(Collection operatorStateHandles) { try { // create an operator state backend - this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles); + this.operatorStateBackend = container.createOperatorStateBackend(this, operatorStateHandles); } catch (Exception e) { throw new IllegalStateException("Could not initialize operator state backend.", e); } @@ -238,11 +301,51 @@ public void dispose() throws Exception { } @Override - public RunnableFuture snapshotState( + public final OperatorSnapshotResult snapshotState( long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { - return operatorStateBackend != null ? 
- operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null; + KeyGroupRange keyGroupRange = null != keyedStateBackend ? + keyedStateBackend.getKeyGroupRange() : KeyGroupRange.EMPTY_KEY_GROUP_RANGE; + + StateSnapshotContextSynchronousImpl snapshotContext = new StateSnapshotContextSynchronousImpl( + checkpointId, timestamp, streamFactory, keyGroupRange, getContainingTask().getCancelables()); + + snapshotState(snapshotContext); + + OperatorSnapshotResult snapshotInProgress = new OperatorSnapshotResult(); + + snapshotInProgress.setKeyedStateRawFuture(snapshotContext.getKeyedStateStreamFuture()); + snapshotInProgress.setOperatorStateRawFuture(snapshotContext.getOperatorStateStreamFuture()); + + if (null != operatorStateBackend) { + snapshotInProgress.setOperatorStateManagedFuture( + operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory)); + } + + if (null != keyedStateBackend) { + snapshotInProgress.setKeyedStateManagedFuture( + keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory)); + } + + return snapshotInProgress; + } + + /** + * Stream operators with state that want to participate in a snapshot need to override this hook method. + * + * @param context context that provides information and means required for taking a snapshot + */ + public void snapshotState(StateSnapshotContext context) throws Exception { + + } + + /** + * Stream operators with state that can be restored need to override this hook method. + * + * @param context context that allows registering different states. + */ + public void initializeState(StateInitializationContext context) throws Exception { + } @@ -283,22 +386,12 @@ public StreamingRuntimeContext getRuntimeContext() { return runtimeContext; } - @SuppressWarnings("rawtypes, unchecked") + @SuppressWarnings("unchecked") public KeyedStateBackend getKeyedStateBackend() { - - if (null == keyedStateBackend) { - initKeyedState(); - } - return (KeyedStateBackend) keyedStateBackend; } public OperatorStateBackend getOperatorStateBackend() { - - if (null == operatorStateBackend) { - initOperatorState(); - } - return operatorStateBackend; } @@ -327,12 +420,12 @@ protected S getPartitionedState(StateDescriptor stateDes * @throws Exception Thrown, if the state backend cannot create the key/value state. */ @SuppressWarnings("unchecked") - protected S getPartitionedState(N namespace, TypeSerializer namespaceSerializer, StateDescriptor stateDescriptor) throws Exception { - if (keyedStateBackend != null) { - return keyedStateBackend.getPartitionedState( - namespace, - namespaceSerializer, - stateDescriptor); + protected S getPartitionedState( + N namespace, TypeSerializer namespaceSerializer, + StateDescriptor stateDescriptor) throws Exception { + + if (keyedStateStore != null) { + return keyedStateStore.getPartitionedState(namespace, namespaceSerializer, stateDescriptor); } else { throw new RuntimeException("Cannot create partitioned state. The keyed state " +
This indicates that the operator is not " + @@ -343,18 +436,18 @@ protected S getPartitionedState(N namespace, TypeSerializer @Override @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContextElement1(StreamRecord record) throws Exception { - setRawKeyContextElement(record, stateKeySelector1); + setKeyContextElement(record, stateKeySelector1); } @Override @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContextElement2(StreamRecord record) throws Exception { - setRawKeyContextElement(record, stateKeySelector2); + setKeyContextElement(record, stateKeySelector2); } - private void setRawKeyContextElement(StreamRecord record, KeySelector selector) throws Exception { + private void setKeyContextElement(StreamRecord record, KeySelector selector) throws Exception { if (selector != null) { - Object key = ((KeySelector) selector).getKey(record.getValue()); + Object key = selector.getKey(record.getValue()); setKeyContext(key); } } @@ -374,6 +467,10 @@ public void setKeyContext(Object key) { } } + public KeyedStateStore getKeyedStateStore() { + return keyedStateStore; + } + // ------------------------------------------------------------------------ // Context and chaining properties // ------------------------------------------------------------------------ @@ -567,4 +664,5 @@ public void close() { output.close(); } } + } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index 72f30b87f0db8..5e1a252544ee3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -28,11 +28,12 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.runtime.state.CheckpointListener; -import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -42,7 +43,6 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.RunnableFuture; import static java.util.Objects.requireNonNull; @@ -73,6 +73,7 @@ public abstract class AbstractUdfStreamOperator public AbstractUdfStreamOperator(F userFunction) { this.userFunction = requireNonNull(userFunction); + checkUdfCheckpointingPreconditions(); } /** @@ -93,22 +94,44 @@ public void setup(StreamTask containingTask, StreamConfig config, Output listCheckpointedFun = (ListCheckpointed) userFunction; + List partitionableState = ((ListCheckpointed) userFunction). 
+ snapshotState(context.getCheckpointId(), context.getCheckpointTimestamp()); ListState listState = getOperatorStateBackend(). - getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); + getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); + + listState.clear(); + + for (Serializable statePartition : partitionableState) { + listState.add(statePartition); + } + } + + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + if (userFunction instanceof CheckpointedFunction) { + ((CheckpointedFunction) userFunction).initializeState(context); + } else if (context.isRestored() && userFunction instanceof ListCheckpointed) { + @SuppressWarnings("unchecked") + ListCheckpointed listCheckpointedFun = (ListCheckpointed) userFunction; + + ListState listState = context.getManagedOperatorStateStore(). + getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME); List list = new ArrayList<>(); @@ -122,6 +145,13 @@ public void open() throws Exception { throw new Exception("Failed to restore state to function: " + e.getMessage(), e); } } + + } + + @Override + public void open() throws Exception { + super.open(); + FunctionUtils.openFunction(userFunction, new Configuration()); } @Override @@ -147,6 +177,7 @@ public void dispose() throws Exception { @Override public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { + if (userFunction instanceof Checkpointed) { @SuppressWarnings("unchecked") Checkpointed chkFunction = (Checkpointed) userFunction; @@ -169,9 +200,9 @@ public void snapshotState(FSDataOutputStream out, long checkpointId, long timest @Override public void restoreState(FSDataInputStream in) throws Exception { - if (userFunction instanceof Checkpointed) { + if (userFunction instanceof CheckpointedRestoring) { @SuppressWarnings("unchecked") - Checkpointed chkFunction = (Checkpointed) userFunction; + CheckpointedRestoring chkFunction = (CheckpointedRestoring) userFunction; int hasUdfState = in.read(); @@ -188,32 +219,6 @@ public void restoreState(FSDataInputStream in) throws Exception { } } - @Override - public RunnableFuture snapshotState( - long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception { - - if (userFunction instanceof CheckpointedFunction) { - ((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp); - } - - if (userFunction instanceof ListCheckpointed) { - @SuppressWarnings("unchecked") - List partitionableState = - ((ListCheckpointed) userFunction).snapshotState(checkpointId, timestamp); - - ListState listState = getOperatorStateBackend(). 
- getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); - - listState.clear(); - - for (Serializable statePartition : partitionableState) { - listState.add(statePartition); - } - } - - return super.snapshotState(checkpointId, timestamp, streamFactory); - } - @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { super.notifyOfCompletedCheckpoint(checkpointId); @@ -251,4 +256,26 @@ public void setOutputType(TypeInformation outTypeInfo, ExecutionConfig exec public Configuration getUserFunctionParameters() { return new Configuration(); } + + private void checkUdfCheckpointingPreconditions() { + + boolean newCheckpointInterface = false; + + if (userFunction instanceof CheckpointedFunction) { + newCheckpointInterface = true; + } + + if (userFunction instanceof ListCheckpointed) { + if (newCheckpointInterface) { + throw new IllegalStateException("User functions are not allowed to implement " + + "CheckpointedFunction AND ListCheckpointed."); + } + newCheckpointInterface = true; + } + + if (newCheckpointInterface && userFunction instanceof Checkpointed) { + throw new IllegalStateException("User functions are not allowed to implement Checkpointed AND " + + "CheckpointedFunction/ListCheckpointed."); + } + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java new file mode 100644 index 0000000000000..52c89f83db00f --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; + +import java.util.concurrent.RunnableFuture; + +/** + * Result of {@link AbstractStreamOperator#snapshotState}.
+ */ +public class OperatorSnapshotResult { + + private RunnableFuture keyedStateManagedFuture; + private RunnableFuture keyedStateRawFuture; + private RunnableFuture operatorStateManagedFuture; + private RunnableFuture operatorStateRawFuture; + + public OperatorSnapshotResult() { + } + + public OperatorSnapshotResult( + RunnableFuture keyedStateManagedFuture, + RunnableFuture keyedStateRawFuture, + RunnableFuture operatorStateManagedFuture, + RunnableFuture operatorStateRawFuture) { + this.keyedStateManagedFuture = keyedStateManagedFuture; + this.keyedStateRawFuture = keyedStateRawFuture; + this.operatorStateManagedFuture = operatorStateManagedFuture; + this.operatorStateRawFuture = operatorStateRawFuture; + } + + public RunnableFuture getKeyedStateManagedFuture() { + return keyedStateManagedFuture; + } + + public void setKeyedStateManagedFuture(RunnableFuture keyedStateManagedFuture) { + this.keyedStateManagedFuture = keyedStateManagedFuture; + } + + public RunnableFuture getKeyedStateRawFuture() { + return keyedStateRawFuture; + } + + public void setKeyedStateRawFuture(RunnableFuture keyedStateRawFuture) { + this.keyedStateRawFuture = keyedStateRawFuture; + } + + public RunnableFuture getOperatorStateManagedFuture() { + return operatorStateManagedFuture; + } + + public void setOperatorStateManagedFuture(RunnableFuture operatorStateManagedFuture) { + this.operatorStateManagedFuture = operatorStateManagedFuture; + } + + public RunnableFuture getOperatorStateRawFuture() { + return operatorStateRawFuture; + } + + public void setOperatorStateRawFuture(RunnableFuture operatorStateRawFuture) { + this.operatorStateRawFuture = operatorStateRawFuture; + } +} \ No newline at end of file diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index fae5fd063b599..f6e547290e04d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,14 +20,12 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; import org.apache.flink.streaming.runtime.tasks.StreamTask; import java.io.Serializable; -import java.util.Collection; -import java.util.concurrent.RunnableFuture; /** * Basic interface for stream operators. Implementers would implement one of @@ -105,7 +103,7 @@ public interface StreamOperator extends Serializable { * the runnable might already be finished. * @throws Exception exception that happened during snapshotting. */ - RunnableFuture snapshotState( + OperatorSnapshotResult snapshotState( long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception; /** @@ -113,7 +111,7 @@ RunnableFuture snapshotState( * * @param stateHandles state handles to the operator state. */ - void restoreState(Collection stateHandles); + void initializeState(OperatorStateHandles stateHandles) throws Exception; /** * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. 
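As a quick illustration of the reworked contract (not part of this change): below is a minimal sketch of a user function written against the new CheckpointedFunction interface. The class and field names are hypothetical, and the sketch assumes that FunctionInitializationContext exposes the same isRestored() and getManagedOperatorStateStore() accessors that AbstractUdfStreamOperator uses on StateInitializationContext; the import locations of the two context types are assumptions as well.

import java.io.Serializable;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext; // assumed location
import org.apache.flink.runtime.state.FunctionSnapshotContext; // assumed location
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

public class CountingMapper implements MapFunction<Long, Long>, CheckpointedFunction {

	private transient ListState<Serializable> countState;
	private long count;

	@Override
	public Long map(Long value) {
		count++;
		return value;
	}

	@Override
	public void initializeState(FunctionInitializationContext context) throws Exception {
		// Register the state on every start; on a restore run, fold the previously
		// snapshotted partitions back into the local count.
		countState = context.getManagedOperatorStateStore()
				.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
		if (context.isRestored()) {
			for (Serializable partition : countState.get()) {
				count += (Long) partition;
			}
		}
	}

	@Override
	public void snapshotState(FunctionSnapshotContext context) throws Exception {
		// Expose the current state through the registered ListState so that it is
		// visible to the backend when the snapshot is drawn.
		countState.clear();
		countState.add(count);
	}
}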
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index cc2e54b57ac5f..cd0489f5fb3e8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -37,8 +37,6 @@ import java.util.List; import java.util.Map; -import static java.util.Objects.requireNonNull; - /** * Implementation of the {@link org.apache.flink.api.common.functions.RuntimeContext}, * for streaming operators. @@ -108,36 +106,17 @@ public C getBroadcastVariableWithInitializer(String name, BroadcastVariab @Override public ValueState getState(ValueStateDescriptor stateProperties) { - requireNonNull(stateProperties, "The state properties must not be null"); - try { - stateProperties.initializeSerializerUnlessSet(getExecutionConfig()); - return operator.getPartitionedState(stateProperties); - } catch (Exception e) { - throw new RuntimeException("Error while getting state", e); - } + return operator.getKeyedStateStore().getState(stateProperties); } @Override public ListState getListState(ListStateDescriptor stateProperties) { - requireNonNull(stateProperties, "The state properties must not be null"); - try { - stateProperties.initializeSerializerUnlessSet(getExecutionConfig()); - ListState originalState = operator.getPartitionedState(stateProperties); - return new UserFacingListState(originalState); - } catch (Exception e) { - throw new RuntimeException("Error while getting state", e); - } + return operator.getKeyedStateStore().getListState(stateProperties); } @Override public ReducingState getReducingState(ReducingStateDescriptor stateProperties) { - requireNonNull(stateProperties, "The state properties must not be null"); - try { - stateProperties.initializeSerializerUnlessSet(getExecutionConfig()); - return operator.getPartitionedState(stateProperties); - } catch (Exception e) { - throw new RuntimeException("Error while getting state", e); - } + return operator.getKeyedStateStore().getReducingState(stateProperties); } // ------------------ expose (read only) relevant information from the stream config -------- // diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index 4d8f6550d8000..ea5f6170c7f9a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -40,8 +40,8 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.metrics.MetricGroup; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; @@ -708,7 +708,7 @@ public void registerProcessingTimeTimer(long time) { if (processingTimeTimers.add(timer)) { Timer oldHead = 
processingTimeTimersQueue.peek(); - long nextTriggerTime = oldHead != null ? oldHead.timestamp : Long.MAX_VALUE; + long nextTriggerTime = oldHead != null ? oldHead.timestamp : Long.MAX_VALUE; processingTimeTimersQueue.add(timer); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java new file mode 100644 index 0000000000000..7abf8d99187d1 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.tasks; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.util.CollectionUtil; +import org.apache.flink.util.Preconditions; + +import java.util.Collection; +import java.util.List; + +/** + * This class holds all state handles for one operator. + */ +@Internal +@VisibleForTesting +public class OperatorStateHandles { + + private final int operatorChainIndex; + + private final StreamStateHandle legacyOperatorState; + + private final Collection managedKeyedState; + private final Collection rawKeyedState; + private final Collection managedOperatorState; + private final Collection rawOperatorState; + + public OperatorStateHandles( + int operatorChainIndex, + StreamStateHandle legacyOperatorState, + Collection managedKeyedState, + Collection rawKeyedState, + Collection managedOperatorState, + Collection rawOperatorState) { + + this.operatorChainIndex = operatorChainIndex; + this.legacyOperatorState = legacyOperatorState; + this.managedKeyedState = managedKeyedState; + this.rawKeyedState = rawKeyedState; + this.managedOperatorState = managedOperatorState; + this.rawOperatorState = rawOperatorState; + } + + public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) { + Preconditions.checkNotNull(taskStateHandles); + + this.operatorChainIndex = operatorChainIndex; + + ChainedStateHandle legacyState = taskStateHandles.getLegacyOperatorState(); + this.legacyOperatorState = ChainedStateHandle.isNullOrEmpty(legacyState) ? 
+ null : legacyState.get(operatorChainIndex); + + this.rawKeyedState = taskStateHandles.getRawKeyedState(); + this.managedKeyedState = taskStateHandles.getManagedKeyedState(); + + this.managedOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex); + this.rawOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex); + } + + public StreamStateHandle getLegacyOperatorState() { + return legacyOperatorState; + } + + public Collection getManagedKeyedState() { + return managedKeyedState; + } + + public Collection getRawKeyedState() { + return rawKeyedState; + } + + public Collection getManagedOperatorState() { + return managedOperatorState; + } + + public Collection getRawOperatorState() { + return rawOperatorState; + } + + public int getOperatorChainIndex() { + return operatorChainIndex; + } + + private static T getSafeItemAtIndexOrNull(List list, int idx) { + return CollectionUtil.isNullOrEmpty(list) ? null : list.get(idx); + } +} \ No newline at end of file diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 2e6ebf3c077c5..eb5fde71e3ff7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -23,9 +23,9 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; -import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.metrics.Gauge; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; @@ -33,7 +33,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.CheckpointStateHandles; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.ClosableRegistry; import org.apache.flink.runtime.state.KeyGroupRange; @@ -42,27 +41,29 @@ import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.OperatorSnapshotResult; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import 
org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.CollectionUtil; +import org.apache.flink.util.FutureUtil; import org.apache.flink.util.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.IOException; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -87,13 +88,14 @@ * * The life cycle of the task is set up as follows: *
{@code
- *  -- getOperatorState() -> restores state of all operators in the chain
+ *  -- setInitialState -> provides state of all operators in the chain
  *
  *  -- invoke()
  *        |
  *        +----> Create basic utils (config, etc) and load the chain of operators
  *        +----> operators.setup()
  *        +----> task specific init()
+ *        +----> initialize-operator-states()
  *        +----> open-operators()
  *        +----> run()
  *        +----> close-operators()
@@ -153,12 +155,7 @@ public abstract class StreamTask>
 	/** The map of user-defined accumulators of this task */
 	private Map> accumulatorMap;
 
-	/** The chained operator state to be restored once the initialization is done */
-	private ChainedStateHandle lazyRestoreChainedOperatorState;
-
-	private List lazyRestoreKeyGroupStates;
-
-	private List> lazyRestoreOperatorState;
+	private TaskStateHandles restoreStateHandles;
 
 
 	/** The currently active background materialization threads */
@@ -251,9 +248,8 @@ public Long getValue() {
 			// -------- Invoke --------
 			LOG.debug("Invoking {}", getName());
 
-			// first order of business is to give operators back their state
-			restoreState();
-			lazyRestoreChainedOperatorState = null; // GC friendliness
+			// first order of business is to give operators their state
+			initializeState();
 
 			// we need to make sure that any triggers scheduled in open() cannot be
 			// executed before all operators are opened
@@ -510,60 +506,8 @@ RecordWriterOutput[] getStreamOutputs() {
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(
-		ChainedStateHandle chainedState,
-		List keyGroupsState,
-		List> partitionableOperatorState) {
-
-		lazyRestoreChainedOperatorState = chainedState;
-		lazyRestoreKeyGroupStates = keyGroupsState;
-		lazyRestoreOperatorState = partitionableOperatorState;
-	}
-
-	private void restoreState() throws Exception {
-		final StreamOperator[] allOperators = operatorChain.getAllOperators();
-
-		if (lazyRestoreChainedOperatorState != null) {
-			Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length,
-					"Invalid Invalid number of operator states. Found :" + lazyRestoreChainedOperatorState.getLength() +
-							". Expected: " + allOperators.length);
-		}
-
-		if (lazyRestoreOperatorState != null) {
-			Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty()
-							|| lazyRestoreOperatorState.size() == allOperators.length,
-					"Invalid number of operator states. Found :" + lazyRestoreOperatorState.size() +
-							". Expected: " + allOperators.length);
-		}
-
-		for (int i = 0; i < allOperators.length; i++) {
-			StreamOperator operator = allOperators[i];
-
-			if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) {
-				operator.restoreState(lazyRestoreOperatorState.get(i));
-			}
-
-			// TODO deprecated code path
-			if (operator instanceof StreamCheckpointedOperator) {
-
-				if (lazyRestoreChainedOperatorState != null) {
-					StreamStateHandle state = lazyRestoreChainedOperatorState.get(i);
-
-					if (state != null) {
-						LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-
-						FSDataInputStream is = state.openInputStream();
-						try {
-							cancelables.registerClosable(is);
-							((StreamCheckpointedOperator) operator).restoreState(is);
-						} finally {
-							cancelables.unregisterClosable(is);
-							is.close();
-						}
-					}
-				}
-			}
-		}
+	public void setInitialState(TaskStateHandles taskStateHandles) {
+		this.restoreStateHandles = taskStateHandles;
 	}
 
 	@Override
@@ -600,117 +544,19 @@ public void triggerCheckpointOnBarrier(CheckpointMetaData checkpointMetaData) th
 
 	private boolean performCheckpoint(CheckpointMetaData checkpointMetaData) throws Exception {
 
-		long checkpointId = checkpointMetaData.getCheckpointId();
-		long timestamp = checkpointMetaData.getTimestamp();
-
-		LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
+		LOG.debug("Starting checkpoint {} on task {}", checkpointMetaData.getCheckpointId(), getName());
 
 		synchronized (lock) {
 			if (isRunning) {
 
-				final long startOfSyncPart = System.nanoTime();
-
 				// Since both state checkpointing and downstream barrier emission occurs in this
 				// lock scope, they are an atomic operation regardless of the order in which they occur.
 				// Given this, we immediately emit the checkpoint barriers, so the downstream operators
 				// can start their checkpoint work as soon as possible
-				operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp);
-
-				// now draw the state snapshot
-				final StreamOperator[] allOperators = operatorChain.getAllOperators();
-
-				final List nonPartitionedStates =
-						Arrays.asList(new StreamStateHandle[allOperators.length]);
-
-				final List operatorStates =
-						Arrays.asList(new OperatorStateHandle[allOperators.length]);
-
-				for (int i = 0; i < allOperators.length; i++) {
-					StreamOperator operator = allOperators[i];
-
-					if (operator != null) {
-
-						final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID());
-
-						CheckpointStreamFactory streamFactory =
-								stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId);
-
-						//TODO deprecated code path
-						if (operator instanceof StreamCheckpointedOperator) {
-
-							CheckpointStreamFactory.CheckpointStateOutputStream outStream =
-									streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
-
-
-							cancelables.registerClosable(outStream);
-
-							try {
-								((StreamCheckpointedOperator) operator).
-										snapshotState(outStream, checkpointId, timestamp);
-
-								nonPartitionedStates.set(i, outStream.closeAndGetHandle());
-							} finally {
-								cancelables.unregisterClosable(outStream);
-							}
-						}
-
-						RunnableFuture handleFuture =
-								operator.snapshotState(checkpointId, timestamp, streamFactory);
-
-						if (null != handleFuture) {
-							//TODO for now we assume there are only synchrous snapshots, no need to start the runnable.
-							if (!handleFuture.isDone()) {
-								throw new IllegalStateException("Currently only supports synchronous snapshots!");
-							}
-
-							operatorStates.set(i, handleFuture.get());
-						}
-					}
-
-				}
-
-				RunnableFuture keyGroupsStateHandleFuture = null;
-
-				if (keyedStateBackend != null) {
-					CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
-							getEnvironment().getJobID(),
-							createOperatorIdentifier(headOperator, configuration.getVertexID()));
-
-					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory);
-				}
-
-				ChainedStateHandle chainedNonPartitionedStateHandles =
-						new ChainedStateHandle<>(nonPartitionedStates);
-
-				ChainedStateHandle chainedPartitionedStateHandles =
-						new ChainedStateHandle<>(operatorStates);
-
-				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName());
-
-				final long syncEndNanos = System.nanoTime();
-				final long syncDurationMillis = (syncEndNanos - startOfSyncPart) / 1_000_000;
-
-				checkpointMetaData.setSyncDurationMillis(syncDurationMillis);
-
-				AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
-						"checkpoint-" + checkpointId + "-" + timestamp,
-						this,
-						cancelables,
-						chainedNonPartitionedStateHandles,
-						chainedPartitionedStateHandles,
-						keyGroupsStateHandleFuture,
-						checkpointMetaData,
-						syncEndNanos);
-
-				cancelables.registerClosable(asyncCheckpointRunnable);
-				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
-
-				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished synchronous part of checkpoint {}." +
-							"Alignment duration: {} ms, snapshot duration {} ms",
-							getName(), checkpointId, checkpointMetaData.getAlignmentDurationNanos() / 1_000_000, syncDurationMillis);
-				}
+				operatorChain.broadcastCheckpointBarrier(
+						checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp());
 
+				checkpointState(checkpointMetaData);
 				return true;
 			} else {
 				return false;
@@ -740,6 +586,59 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception {
 		}
 	}
 
+	private void checkpointState(CheckpointMetaData checkpointMetaData) throws Exception {
+		CheckpointingOperation checkpointingOperation = new CheckpointingOperation(this, checkpointMetaData);
+		checkpointingOperation.executeCheckpointing();
+	}
+
+	private void initializeState() throws Exception {
+
+		boolean restored = null != restoreStateHandles;
+
+		if (restored) {
+
+			checkRestorePreconditions(operatorChain.getChainLength());
+			initializeOperators(true);
+			restoreStateHandles = null; // free for GC
+		} else {
+			initializeOperators(false);
+		}
+	}
+
+	private void initializeOperators(boolean restored) throws Exception {
+		StreamOperator[] allOperators = operatorChain.getAllOperators();
+		for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) {
+			StreamOperator operator = allOperators[chainIdx];
+			if (null != operator) {
+				if (restored) {
+					operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx));
+				} else {
+					operator.initializeState(null);
+				}
+			}
+		}
+	}
+
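+	/**
+	 * Sanity check: legacy (non-partitionable) operator state is indexed by position in the
+	 * chain, so when present its length must match the chain length exactly, while the managed
+	 * operator state may be absent entirely but must otherwise also match the chain length.
+	 */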
+	private void checkRestorePreconditions(int operatorChainLength) {
+
+		ChainedStateHandle nonPartitionableOperatorStates =
+				restoreStateHandles.getLegacyOperatorState();
+		List> operatorStates =
+				restoreStateHandles.getManagedOperatorState();
+
+		if (nonPartitionableOperatorStates != null) {
+			Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength,
+					"Invalid number of operator states. Found: " + nonPartitionableOperatorStates.getLength()
+							+ ". Expected: " + operatorChainLength);
+		}
+
+		if (!CollectionUtil.isNullOrEmpty(operatorStates)) {
+			Preconditions.checkArgument(operatorStates.size() == operatorChainLength,
+					"Invalid number of operator states. Found: " + operatorStates.size() +
+							". Expected: " + operatorChainLength);
+		}
+	}
+
 	// ------------------------------------------------------------------------
 	//  State backend
 	// ------------------------------------------------------------------------
@@ -777,7 +676,8 @@ private AbstractStateBackend createStateBackend() throws Exception {
 					try {
 						@SuppressWarnings("rawtypes")
 						Class clazz =
-								Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class);
+								Class.forName(backendName, false, getUserCodeClassLoader()).
+										asSubclass(StateBackendFactory.class);
 
 						stateBackend = clazz.newInstance().createFromConfig(flinkConfig);
 					} catch (ClassNotFoundException e) {
@@ -799,7 +699,7 @@ public OperatorStateBackend createOperatorStateBackend(
 			StreamOperator op, Collection restoreStateHandles) throws Exception {
 
 		Environment env = getEnvironment();
-		String opId = createOperatorIdentifier(op, configuration.getVertexID());
+		String opId = createOperatorIdentifier(op, getConfiguration().getVertexID());
 
 		OperatorStateBackend newBackend = restoreStateHandles == null ?
 				stateBackend.createOperatorStateBackend(env, opId)
@@ -823,7 +723,7 @@ public  AbstractKeyedStateBackend createKeyedStateBackend(
 				headOperator,
 				configuration.getVertexID());
 
-		if (lazyRestoreKeyGroupStates != null) {
+		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
 			keyedStateBackend = stateBackend.restoreKeyedStateBackend(
 					getEnvironment(),
 					getEnvironment().getJobID(),
@@ -831,10 +731,10 @@ public  AbstractKeyedStateBackend createKeyedStateBackend(
 					keySerializer,
 					numberOfKeyGroups,
 					keyGroupRange,
-					lazyRestoreKeyGroupStates,
+					restoreStateHandles.getManagedKeyedState(),
 					getEnvironment().getTaskKvStateRegistry());
 
-			lazyRestoreKeyGroupStates = null; // GC friendliness
+			restoreStateHandles = null; // GC friendliness
 		} else {
 			keyedStateBackend = stateBackend.createKeyedStateBackend(
 					getEnvironment(),
@@ -913,62 +813,60 @@ public String toString() {
 
 	// ------------------------------------------------------------------------
 
-	private static class AsyncCheckpointRunnable implements Runnable, Closeable {
+	private static final class AsyncCheckpointRunnable implements Runnable, Closeable {
 
 		private final StreamTask owner;
 
-		private final ClosableRegistry cancelables;
-
-		private final ChainedStateHandle nonPartitionedStateHandles;
-
-		private final ChainedStateHandle partitioneableStateHandles;
+		private final List snapshotInProgressList;
 
-		private final RunnableFuture keyGroupsStateHandleFuture;
+		RunnableFuture futureKeyedBackendStateHandles;
+		RunnableFuture futureKeyedStreamStateHandles;
 
-		private final String name;
+		List nonPartitionedStateHandles;
 
 		private final CheckpointMetaData checkpointMetaData;
 
 		private final long asyncStartNanos;
 
 		AsyncCheckpointRunnable(
-				String name,
 				StreamTask owner,
-				ClosableRegistry cancelables,
-				ChainedStateHandle nonPartitionedStateHandles,
-				ChainedStateHandle partitioneableStateHandles,
-				RunnableFuture keyGroupsStateHandleFuture,
+				List nonPartitionedStateHandles,
+				List snapshotInProgressList,
 				CheckpointMetaData checkpointMetaData,
-				long asyncStartNanos
-		) {
+				long asyncStartNanos) {
 
-			this.name = name;
-			this.owner = owner;
-			this.cancelables = cancelables;
+			this.owner = Preconditions.checkNotNull(owner);
+			this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
 			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
-			this.partitioneableStateHandles = partitioneableStateHandles;
-			this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture;
-			this.checkpointMetaData = checkpointMetaData;
 			this.asyncStartNanos = asyncStartNanos;
+
+			if (!snapshotInProgressList.isEmpty()) {
+				// TODO Currently only the head operator of a chain can have keyed state, so simply access it directly.
+				int headIndex = snapshotInProgressList.size() - 1;
+				OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex);
+				if (null != snapshotInProgress) {
+					this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture();
+					this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture();
+				}
+			}
 		}
 
 		@Override
 		public void run() {
+
 			try {
 
-				List keyedStates = Collections.emptyList();
+				// Keyed state handle futures; currently only one operator (the head) can have keyed state
+				KeyGroupsStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
+				KeyGroupsStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
 
-				if (keyGroupsStateHandleFuture != null) {
+				List operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size());
+				List operatorStatesStream = new ArrayList<>(snapshotInProgressList.size());
 
-					if (!keyGroupsStateHandleFuture.isDone()) {
-						//TODO this currently works because we only have one RunnableFuture
-						keyGroupsStateHandleFuture.run();
-					}
-
-					KeyGroupsStateHandle keyGroupsStateHandle = this.keyGroupsStateHandleFuture.get();
-					if (keyGroupsStateHandle != null) {
-						keyedStates = Collections.singletonList(keyGroupsStateHandle);
-					}
+				for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) {
+					operatorStatesBackend.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()));
+					operatorStatesStream.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()));
 				}
 
 				final long asyncEndNanos = System.nanoTime();
@@ -976,37 +874,161 @@ public void run() {
 
 				checkpointMetaData.setAsyncDurationMillis(asyncDurationMillis);
 
-				if (nonPartitionedStateHandles.isEmpty() && partitioneableStateHandles.isEmpty() && keyedStates.isEmpty()) {
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
-				} else {
-					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
-							nonPartitionedStateHandles,
-							partitioneableStateHandles,
-							keyedStates);
+				ChainedStateHandle chainedNonPartitionedOperatorsState =
+						new ChainedStateHandle<>(nonPartitionedStateHandles);
 
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, allStateHandles);
+				ChainedStateHandle chainedOperatorStateBackend =
+						new ChainedStateHandle<>(operatorStatesBackend);
+
+				ChainedStateHandle chainedOperatorStateStream =
+						new ChainedStateHandle<>(operatorStatesStream);
+
+				SubtaskState subtaskState = new SubtaskState(
+						chainedNonPartitionedOperatorsState,
+						chainedOperatorStateBackend,
+						chainedOperatorStateStream,
+						keyedStateHandleBackend,
+						keyedStateHandleStream);
+
+				if (subtaskState.hasState()) {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, subtaskState);
+				} else {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", 
+					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
 							owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis);
 				}
-			}
-			catch (Exception e) {
+			} catch (Exception e) {
 				// registers the exception and tries to fail the whole task
 				AsynchronousException asyncException = new AsynchronousException(e);
 				owner.handleAsyncException("Failure in asynchronous checkpoint materialization", asyncException);
-			}
-			finally {
-				cancelables.unregisterClosable(this);
+			} finally {
+				owner.cancelables.unregisterClosable(this);
 			}
 		}
 
 		@Override
 		public void close() {
-			if (keyGroupsStateHandleFuture != null) {
-				keyGroupsStateHandleFuture.cancel(true);
+			//TODO Handle other state futures in case we actually run them. Currently they are just DoneFutures.
+			if (futureKeyedBackendStateHandles != null) {
+				futureKeyedBackendStateHandles.cancel(true);
+			}
+		}
+	}
+
+	public ClosableRegistry getCancelables() {
+		return cancelables;
+	}
+
+	// ------------------------------------------------------------------------
+
+	private static final class CheckpointingOperation {
+
+		private final StreamTask owner;
+
+		private final CheckpointMetaData checkpointMetaData;
+
+		private final StreamOperator[] allOperators;
+
+		private long startSyncPartNano;
+		private long startAsyncPartNano;
+
+		// ------------------------
+
+		private CheckpointStreamFactory streamFactory;
+
+		private final List nonPartitionedStates;
+		private final List snapshotInProgressList;
+
+		public CheckpointingOperation(StreamTask owner, CheckpointMetaData checkpointMetaData) {
+			this.owner = Preconditions.checkNotNull(owner);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
+			this.allOperators = owner.operatorChain.getAllOperators();
+			this.nonPartitionedStates = new ArrayList<>(allOperators.length);
+			this.snapshotInProgressList = new ArrayList<>(allOperators.length);
+		}
+
+		public void executeCheckpointing() throws Exception {
+
+			startSyncPartNano = System.nanoTime();
+
+			for (StreamOperator op : allOperators) {
+
+				createStreamFactory(op);
+				snapshotNonPartitionableState(op);
+
+				OperatorSnapshotResult snapshotInProgress =
+						op.snapshotState(checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp(), streamFactory);
+
+				snapshotInProgressList.add(snapshotInProgress);
 			}
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}",
+						checkpointMetaData.getCheckpointId(), owner.getName());
+			}
+
+			startAsyncPartNano = System.nanoTime();
+
+			checkpointMetaData.setSyncDurationMillis((startAsyncPartNano - startSyncPartNano) / 1_000_000);
+
+			runAsyncCheckpointingAndAcknowledge();
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("{} - finished synchronous part of checkpoint {}. " +
+								"Alignment duration: {} ms, snapshot duration {} ms",
+						owner.getName(), checkpointMetaData.getCheckpointId(),
+						checkpointMetaData.getAlignmentDurationNanos() / 1_000_000,
+						checkpointMetaData.getSyncDurationMillis());
+			}
+		}
+
+		private void createStreamFactory(StreamOperator operator) throws IOException {
+			String operatorId = owner.createOperatorIdentifier(operator, owner.configuration.getVertexID());
+			this.streamFactory = owner.stateBackend.createStreamFactory(owner.getEnvironment().getJobID(), operatorId);
+		}
+
+		//TODO deprecated code path
+		private void snapshotNonPartitionableState(StreamOperator operator) throws Exception {
+
+			StreamStateHandle stateHandle = null;
+
+			if (operator instanceof StreamCheckpointedOperator) {
+
+				CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+						streamFactory.createCheckpointStateOutputStream(
+								checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp());
+
+				owner.cancelables.registerClosable(outStream);
+
+				try {
+					((StreamCheckpointedOperator) operator).
+							snapshotState(
+									outStream,
+									checkpointMetaData.getCheckpointId(),
+									checkpointMetaData.getTimestamp());
+
+					stateHandle = outStream.closeAndGetHandle();
+				} finally {
+					owner.cancelables.unregisterClosable(outStream);
+					outStream.close();
+				}
+			}
+			nonPartitionedStates.add(stateHandle);
+		}
+
+		public void runAsyncCheckpointingAndAcknowledge() throws IOException {
+			AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
+					owner,
+					nonPartitionedStates,
+					snapshotInProgressList,
+					checkpointMetaData,
+					startAsyncPartNano);
+
+			owner.cancelables.registerClosable(asyncCheckpointRunnable);
+			owner.asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 		}
 	}
 }
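To make the synchronous/asynchronous split above concrete: a minimal sketch of how the futures bundled in an OperatorSnapshotResult are materialized on the async thread, condensed from AsyncCheckpointRunnable. The helper class is hypothetical, and the sketch assumes (as the usage above suggests) that FutureUtil.runIfNotDoneAndGet returns null when handed a null future.

import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
import org.apache.flink.util.FutureUtil;

final class SnapshotMaterializationSketch {

	static void materialize(OperatorSnapshotResult result) throws Exception {
		// Keyed state futures: currently only the head operator of a chain sets these.
		KeyGroupsStateHandle managedKeyed =
				FutureUtil.runIfNotDoneAndGet(result.getKeyedStateManagedFuture());
		KeyGroupsStateHandle rawKeyed =
				FutureUtil.runIfNotDoneAndGet(result.getKeyedStateRawFuture());

		// Operator state futures: one pair of handles per operator in the chain.
		OperatorStateHandle managedOperator =
				FutureUtil.runIfNotDoneAndGet(result.getOperatorStateManagedFuture());
		OperatorStateHandle rawOperator =
				FutureUtil.runIfNotDoneAndGet(result.getOperatorStateRawFuture());

		// The handles would then be packaged into a SubtaskState and acknowledged to the
		// JobManager, as done at the end of AsyncCheckpointRunnable.run() above.
	}
}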
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
new file mode 100644
index 0000000000000..cbb833bc5c29e
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This test verifies the lifecycle of AbstractUdfStreamOperator, including its UDF handling.
+ */
+public class AbstractUdfStreamOperatorLifecycleTest {
+
+	private static final List<String> EXPECTED_CALL_ORDER_FULL = Arrays.asList(
+			"OPERATOR::setup",
+			"UDF::setRuntimeContext",
+			"OPERATOR::initializeState",
+			"OPERATOR::open",
+			"UDF::open",
+			"OPERATOR::run",
+			"UDF::run",
+			"OPERATOR::snapshotState",
+			"OPERATOR::close",
+			"UDF::close",
+			"OPERATOR::dispose");
+
+	private static final List<String> EXPECTED_CALL_ORDER_CANCEL_RUNNING = Arrays.asList(
+			"OPERATOR::setup",
+			"UDF::setRuntimeContext",
+			"OPERATOR::initializeState",
+			"OPERATOR::open",
+			"UDF::open",
+			"OPERATOR::run",
+			"UDF::run",
+			"OPERATOR::cancel",
+			"UDF::cancel",
+			"OPERATOR::dispose",
+			"UDF::close");
+
+	private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " +
+			"getMetricGroup[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " +
+			"notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " +
+			"org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " +
+			"org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " +
+			"setKeyContextElement2[class org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " +
+			"setup[class org.apache.flink.streaming.runtime.tasks.StreamTask, class " +
+			"org.apache.flink.streaming.api.graph.StreamConfig, interface " +
+			"org.apache.flink.streaming.api.operators.Output], snapshotState[long, long, " +
+			"interface org.apache.flink.runtime.state.CheckpointStreamFactory]]";
+
+	private static final String ALL_METHODS_RICH_FUNCTION = "[close[], getIterationRuntimeContext[], getRuntimeContext[]" +
+			", open[class org.apache.flink.configuration.Configuration], setRuntimeContext[interface " +
+			"org.apache.flink.api.common.functions.RuntimeContext]]";
+
+	private static final List<String> ACTUAL_ORDER_TRACKING =
+			Collections.synchronizedList(new ArrayList<String>(EXPECTED_CALL_ORDER_FULL.size()));
+
+	@Test
+	public void testAllMethodsRegisteredInTest() {
+		List<String> methodsWithSignatureString = new ArrayList<>();
+		for (Method method : StreamOperator.class.getMethods()) {
+			methodsWithSignatureString.add(method.getName() + Arrays.toString(method.getParameterTypes()));
+		}
+		Collections.sort(methodsWithSignatureString);
+		Assert.assertEquals("It seems like new methods have been introduced to " + StreamOperator.class +
+				". Please register them with this test and ensure to document their position in the lifecycle " +
+				"(if applicable).", ALL_METHODS_STREAM_OPERATOR, methodsWithSignatureString.toString());
+
+		methodsWithSignatureString = new ArrayList<>();
+		for (Method method : RichFunction.class.getMethods()) {
+			methodsWithSignatureString.add(method.getName() + Arrays.toString(method.getParameterTypes()));
+		}
+		Collections.sort(methodsWithSignatureString);
+		Assert.assertEquals("It seems like new methods have been introduced to " + RichFunction.class +
+				". Please register them with this test and ensure to document their position in the lifecycle " +
+				"(if applicable).", ALL_METHODS_RICH_FUNCTION, methodsWithSignatureString.toString());
+	}
+
+	@Test
+	public void testLifeCycleFull() throws Exception {
+		ACTUAL_ORDER_TRACKING.clear();
+
+		Configuration taskManagerConfig = new Configuration();
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		MockSourceFunction srcFun = new MockSourceFunction();
+
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource<>(srcFun, true));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+
+		LifecycleTrackingStreamSource.runStarted.await();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
+		assertEquals(EXPECTED_CALL_ORDER_FULL, ACTUAL_ORDER_TRACKING);
+	}
+
+	@Test
+	public void testLifeCycleCancel() throws Exception {
+		ACTUAL_ORDER_TRACKING.clear();
+
+		Configuration taskManagerConfig = new Configuration();
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		MockSourceFunction srcFun = new MockSourceFunction();
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource<>(srcFun, false));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+		LifecycleTrackingStreamSource.runStarted.await();
+
+		// this should cancel the task even though it is blocked on runFinished
+		task.cancelExecution();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.CANCELED, task.getExecutionState());
+		assertEquals(EXPECTED_CALL_ORDER_CANCEL_RUNNING, ACTUAL_ORDER_TRACKING);
+	}
+
+	private static class MockSourceFunction extends RichSourceFunction<Long> {
+
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void run(SourceContext<Long> ctx) {
+			ACTUAL_ORDER_TRACKING.add("UDF::run");
+		}
+
+		@Override
+		public void cancel() {
+			ACTUAL_ORDER_TRACKING.add("UDF::cancel");
+		}
+
+		@Override
+		public void setRuntimeContext(RuntimeContext t) {
+			ACTUAL_ORDER_TRACKING.add("UDF::setRuntimeContext");
+			super.setRuntimeContext(t);
+		}
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("UDF::open");
+			super.open(parameters);
+		}
+
+		@Override
+		public void close() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("UDF::close");
+			super.close();
+		}
+	}
+
+	private static class LifecycleTrackingStreamSource<OUT, SRC extends SourceFunction<OUT>>
+			extends StreamSource<OUT, SRC> implements Serializable {
+
+		private static final long serialVersionUID = 2431488948886850562L;
+		private transient Thread testCheckpointer;
+
+		private final boolean simulateCheckpointing;
+
+		static OneShotLatch runStarted;
+		static OneShotLatch runFinish;
+
+		public LifecycleTrackingStreamSource(SRC sourceFunction, boolean simulateCheckpointing) {
+			super(sourceFunction);
+			this.simulateCheckpointing = simulateCheckpointing;
+			runStarted = new OneShotLatch();
+			runFinish = new OneShotLatch();
+		}
+
+		@Override
+		public void run(Object lockingObject, Output<StreamRecord<OUT>> collector) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::run");
+			super.run(lockingObject, collector);
+			runStarted.trigger();
+			runFinish.await();
+		}
+
+		@Override
+		public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::setup");
+			super.setup(containingTask, config, output);
+			if (simulateCheckpointing) {
+				testCheckpointer = new Thread() {
+					@Override
+					public void run() {
+						long id = 0;
+						while (true) {
+							try {
+								Thread.sleep(50);
+								if (getContainingTask().isCanceled() || getContainingTask().triggerCheckpoint(
+										new CheckpointMetaData(id++, System.currentTimeMillis()))) {
+									LifecycleTrackingStreamSource.runFinish.trigger();
+									break;
+								}
+							} catch (Exception e) {
+								e.printStackTrace();
+								Assert.fail();
+							}
+						}
+					}
+				};
+				testCheckpointer.start();
+			}
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::snapshotState");
+			super.snapshotState(context);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::initializeState");
+			super.initializeState(context);
+		}
+
+		@Override
+		public void open() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::open");
+			super.open();
+		}
+
+		@Override
+		public void close() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::close");
+			super.close();
+		}
+
+		@Override
+		public void cancel() {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::cancel");
+			super.cancel();
+		}
+
+		@Override
+		public void dispose() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::dispose");
+			super.dispose();
+			if (simulateCheckpointing) {
+				testCheckpointer.join();
+			}
+		}
+	}
+}
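testAllMethodsRegisteredInTest above freezes the reflective signature lists of StreamOperator and RichFunction into string constants, so that any newly added interface method breaks the test until its place in the lifecycle is documented. A self-contained sketch of the same guard, applied to java.io.Closeable purely as a stand-in interface:

	import java.lang.reflect.Method;
	import java.util.ArrayList;
	import java.util.Arrays;
	import java.util.Collections;
	import java.util.List;

	public class SignatureGuard {
		// Frozen snapshot of java.io.Closeable's methods; update deliberately when the API changes.
		static final String EXPECTED = "[close[]]";

		public static void main(String[] args) {
			List<String> sigs = new ArrayList<>();
			for (Method m : java.io.Closeable.class.getMethods()) {
				sigs.add(m.getName() + Arrays.toString(m.getParameterTypes()));
			}
			Collections.sort(sigs);
			if (!EXPECTED.equals(sigs.toString())) {
				throw new AssertionError("Interface changed, update lifecycle expectations: " + sigs);
			}
		}
	}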
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
new file mode 100644
index 0000000000000..75c2261da59b4
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContextImpl;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.LongArrayList;
+import org.apache.flink.util.Preconditions;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.mockito.Mockito.mock;
+
+public class StateInitializationContextImplTest {
+
+	static final int NUM_HANDLES = 10;
+
+	private StateInitializationContextImpl initializationContext;
+	private ClosableRegistry closableRegistry;
+
+	private int writtenKeyGroups;
+	private Set<Integer> writtenOperatorStates;
+
+	@Before
+	public void setUp() throws Exception {
+
+
+		this.writtenKeyGroups = 0;
+		this.writtenOperatorStates = new HashSet<>();
+
+		this.closableRegistry = new ClosableRegistry();
+		OperatorStateStore stateStore = mock(OperatorStateStore.class);
+
+		ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);
+
+		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>(NUM_HANDLES);
+		int prev = 0;
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			out.reset();
+			int size = i % 4;
+			int end = prev + size;
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			KeyGroupRangeOffsets offsets =
+					new KeyGroupRangeOffsets(i == 9 ? KeyGroupRange.EMPTY_KEY_GROUP_RANGE : new KeyGroupRange(prev, end));
+			prev = end + 1;
+			for (int kg : offsets.getKeyGroupRange()) {
+				offsets.setKeyGroupOffset(kg, out.getPosition());
+				dov.writeInt(kg);
+				++writtenKeyGroups;
+			}
+
+			KeyGroupsStateHandle handle =
+					new KeyGroupsStateHandle(offsets, new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
+
+			keyGroupsStateHandles.add(handle);
+		}
+
+		List<OperatorStateHandle> operatorStateHandles = new ArrayList<>(NUM_HANDLES);
+
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			int size = i % 4;
+			out.reset();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			LongArrayList offsets = new LongArrayList(size);
+			for (int s = 0; s < size; ++s) {
+				offsets.add(out.getPosition());
+				int val = i * NUM_HANDLES + s;
+				dov.writeInt(val);
+				writtenOperatorStates.add(val);
+			}
+
+			Map<String, long[]> offsetsMap = new HashMap<>();
+			offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray());
+			OperatorStateHandle operatorStateHandle =
+					new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
+			operatorStateHandles.add(operatorStateHandle);
+		}
+
+		this.initializationContext =
+				new StateInitializationContextImpl(
+						true,
+						stateStore,
+						mock(KeyedStateStore.class),
+						keyGroupsStateHandles,
+						operatorStateHandles,
+						closableRegistry);
+	}
+
+	@Test
+	public void getOperatorStateStreams() throws Exception {
+
+	}
+
+	@Test
+	public void getKeyedStateStreams() throws Exception {
+
+		int readKeyGroupCount = 0;
+
+		for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+				: initializationContext.getRawKeyedStateInputs()) {
+
+			Assert.assertNotNull(stateStreamProvider);
+
+			try (InputStream is = stateStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				int val = div.readInt();
+				++readKeyGroupCount;
+				Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+			}
+		}
+
+		Assert.assertEquals(writtenKeyGroups, readKeyGroupCount);
+	}
+
+	@Test
+	public void getOperatorStateStore() throws Exception {
+
+		Set<Integer> readStatesCount = new HashSet<>();
+
+		for (StatePartitionStreamProvider statePartitionStreamProvider
+				: initializationContext.getRawOperatorStateInputs()) {
+
+			Assert.assertNotNull(statePartitionStreamProvider);
+
+			try (InputStream is = statePartitionStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				Assert.assertTrue(readStatesCount.add(div.readInt()));
+			}
+		}
+
+		Assert.assertEquals(writtenOperatorStates, readStatesCount);
+	}
+
+	@Test
+	public void close() throws Exception {
+
+		int count = 0;
+		int stopCount = NUM_HANDLES / 2;
+		boolean isClosed = false;
+
+
+		try {
+			for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+					: initializationContext.getRawKeyedStateInputs()) {
+				Assert.assertNotNull(stateStreamProvider);
+
+				if (count == stopCount) {
+					initializationContext.close();
+					isClosed = true;
+				}
+
+				try (InputStream is = stateStreamProvider.getStream()) {
+					DataInputView div = new DataInputViewStreamWrapper(is);
+					try {
+						int val = div.readInt();
+						Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+						if (isClosed) {
+							Assert.fail("Close was ignored: stream");
+						}
+						++count;
+					} catch (IOException ioex) {
+						if (!isClosed) {
+							throw ioex;
+						}
+					}
+				}
+			}
+			Assert.fail("Close was ignored: registry");
+		} catch (IOException iex) {
+			Assert.assertTrue(isClosed);
+			Assert.assertEquals(stopCount, count);
+		}
+
+	}
+
+	static final class ByteStateHandleCloseChecking extends ByteStreamStateHandle {
+
+		private static final long serialVersionUID = -6201941296931334140L;
+
+		public ByteStateHandleCloseChecking(String handleName, byte[] data) {
+			super(handleName, data);
+		}
+
+		@Override
+		public FSDataInputStream openInputStream() throws IOException {
+			return new FSDataInputStream() {
+				private int index = 0;
+				private boolean closed = false;
+
+				@Override
+				public void seek(long desired) throws IOException {
+					Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
+					index = (int) desired;
+				}
+
+				@Override
+				public long getPos() throws IOException {
+					return index;
+				}
+
+				@Override
+				public int read() throws IOException {
+					if (closed) {
+						throw new IOException("Stream closed");
+					}
+					return index < data.length ? data[index++] & 0xFF : -1;
+				}
+
+				@Override
+				public void close() throws IOException {
+					super.close();
+					this.closed = true;
+				}
+			};
+		}
+	}
+
+}
\ No newline at end of file
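The ByteStateHandleCloseChecking helper above exists because a plain ByteArrayInputStream silently keeps working after close(), which would let the close-propagation bug that the close() test hunts for slip through. A stripped-down version of the same idiom in plain Java, independent of the Flink stream types:

	import java.io.IOException;
	import java.io.InputStream;

	// An in-memory stream that, unlike ByteArrayInputStream, really enforces close().
	class CloseCheckingStream extends InputStream {
		private final byte[] data;
		private int index;
		private boolean closed;

		CloseCheckingStream(byte[] data) {
			this.data = data;
		}

		@Override
		public int read() throws IOException {
			if (closed) {
				throw new IOException("Stream closed");
			}
			return index < data.length ? (data[index++] & 0xFF) : -1;
		}

		@Override
		public void close() {
			closed = true;
		}
	}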
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
new file mode 100644
index 0000000000000..0ee839e8a7cd5
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class StateSnapshotContextSynchronousImplTest {
+
+	private StateSnapshotContextSynchronousImpl snapshotContext;
+
+	@Before
+	public void setUp() throws Exception {
+		ClosableRegistry closableRegistry = new ClosableRegistry();
+		CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(1024);
+		KeyGroupRange keyGroupRange = new KeyGroupRange(0, 2);
+		this.snapshotContext = new StateSnapshotContextSynchronousImpl(42, 4711, streamFactory, keyGroupRange, closableRegistry);
+	}
+
+	@Test
+	public void testMetaData() {
+		Assert.assertEquals(42, snapshotContext.getCheckpointId());
+		Assert.assertEquals(4711, snapshotContext.getCheckpointTimestamp());
+	}
+
+	@Test
+	public void testCreateRawKeyedStateOutput() throws Exception {
+		KeyedStateCheckpointOutputStream stream = snapshotContext.getRawKeyedOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+
+	@Test
+	public void testCreateRawOperatorStateOutput() throws Exception {
+		OperatorStateCheckpointOutputStream stream = snapshotContext.getRawOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+}
\ No newline at end of file
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
new file mode 100644
index 0000000000000..ada0b86b461bd
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.util.FutureUtil;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.InputStream;
+import java.util.BitSet;
+import java.util.Collections;
+
+public class StreamOperatorSnapshotRestoreTest {
+
+	@Test
+	public void testOperatorStatesSnapshotRestore() throws Exception {
+
+		//-------------------------------------------------------------------------- snapshot
+
+		TestOneInputStreamOperator op = new TestOneInputStreamOperator(false);
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
+				new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+					@Override
+					public Integer getKey(Integer value) throws Exception {
+						return value;
+					}
+				}, TypeInformation.of(Integer.class));
+
+		testHarness.open();
+
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		OperatorSnapshotResult snapshotInProgress = testHarness.snapshot(1L, 1L);
+
+		KeyGroupsStateHandle keyedManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture());
+		KeyGroupsStateHandle keyedRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture());
+
+		OperatorStateHandle opManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture());
+		OperatorStateHandle opRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture());
+
+		testHarness.close();
+
+		//-------------------------------------------------------------------------- restore
+
+		op = new TestOneInputStreamOperator(true);
+		testHarness = new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}, TypeInformation.of(Integer.class));
+
+		testHarness.initializeState(new OperatorStateHandles(
+				0,
+				null,
+				Collections.singletonList(keyedManaged),
+				Collections.singletonList(keyedRaw),
+				Collections.singletonList(opManaged),
+				Collections.singletonList(opRaw)));
+
+		testHarness.open();
+
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		testHarness.close();
+	}
+
+	static class TestOneInputStreamOperator
+			extends AbstractStreamOperator<Integer>
+			implements OneInputStreamOperator<Integer, Integer> {
+
+		private static final long serialVersionUID = -8942866418598856475L;
+
+		public TestOneInputStreamOperator(boolean verifyRestore) {
+			this.verifyRestore = verifyRestore;
+		}
+
+		private boolean verifyRestore;
+		private ValueState<Integer> keyedState;
+		private ListState<Integer> opState;
+
+		@Override
+		public void processElement(StreamRecord<Integer> element) throws Exception {
+			if (verifyRestore) {
+				// check restored managed keyed state
+				long exp = element.getValue() + 1;
+				long act = keyedState.value();
+				Assert.assertEquals(exp, act);
+			} else {
+				// write managed keyed state that goes into snapshot
+				keyedState.update(element.getValue() + 1);
+				// write managed operator state that goes into snapshot
+				opState.add(element.getValue());
+			}
+		}
+
+		@Override
+		public void processWatermark(Watermark mark) throws Exception {
+
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+
+			KeyedStateCheckpointOutputStream out = context.getRawKeyedOperatorStateOutput();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			// write raw keyed state that goes into snapshot
+			int count = 0;
+			for (int kg : out.getKeyGroupList()) {
+				out.startNewKeyGroup(kg);
+				dov.writeInt(kg + 2);
+				++count;
+			}
+
+			Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+			// write raw operator state that goes into snapshot
+			OperatorStateCheckpointOutputStream outOp = context.getRawOperatorStateOutput();
+			dov = new DataOutputViewStreamWrapper(outOp);
+			for (int i = 0; i < 13; ++i) {
+				outOp.startNewPartition();
+				dov.writeInt(42 + i);
+			}
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+
+			Assert.assertEquals(verifyRestore, context.isRestored());
+
+			keyedState = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("managed-keyed", Integer.class, 0));
+			opState = context.getManagedOperatorStateStore().getSerializableListState("managed-op-state");
+
+			if (context.isRestored()) {
+				// check restored raw keyed state
+				int count = 0;
+				for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						Assert.assertEquals(streamProvider.getKeyGroupId() + 2, div.readInt());
+						++count;
+					}
+				}
+				Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+				// check restored managed operator state
+				BitSet check = new BitSet(10);
+				for (int v : opState.get()) {
+					check.set(v);
+				}
+
+				Assert.assertEquals(10, check.cardinality());
+
+				// check restored raw operator state
+				check = new BitSet(13);
+				for (StatePartitionStreamProvider streamProvider : context.getRawOperatorStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						check.set(div.readInt() - 42);
+					}
+				}
+				Assert.assertEquals(13, check.cardinality());
+			}
+		}
+	}
+
+}
\ No newline at end of file
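The restore path above uses BitSet as a compact "seen exactly once" tally: each restored value maps to one bit, and the final cardinality must equal the number of partitions written, so both lost and duplicated partitions are detected. A standalone illustration of that verification idea:

	import java.util.BitSet;

	public class BitSetCheck {
		public static void main(String[] args) {
			// Pretend these five values came out of the raw operator state streams.
			int[] restored = {42, 43, 44, 45, 46};
			BitSet check = new BitSet(restored.length);
			for (int v : restored) {
				check.set(v - 42); // same offset trick as the raw operator state above
			}
			// cardinality() counts distinct set bits: with a fixed number of reads,
			// a duplicate implies a missing value, so the count drops below expected.
			if (check.cardinality() != restored.length) {
				throw new AssertionError("lost or duplicated state partitions");
			}
		}
	}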
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index 02409a39a2735..ff7acd76026ed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -37,6 +37,7 @@
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -53,7 +54,9 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
 
 public class StreamingRuntimeContextTest {
@@ -160,18 +163,22 @@ private static AbstractStreamOperator createDescriptorCapturingMockOp(
 			final AtomicReference<Object> ref, final ExecutionConfig config) throws Exception {
 		
 		AbstractStreamOperator operatorMock = mock(AbstractStreamOperator.class);
+
+		DefaultKeyedStateStore keyedStateStore = spy(new DefaultKeyedStateStore(mock(AbstractKeyedStateBackend.class), config));
+
 		when(operatorMock.getExecutionConfig()).thenReturn(config);
-		
-		when(operatorMock.getPartitionedState(any(StateDescriptor.class))).thenAnswer(
-				new Answer() {
-					
-					@Override
-					public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ref.set(invocationOnMock.getArguments()[0]);
-						return null;
-					}
-				});
-		
+
+		doAnswer(new Answer() {
+
+			@Override
+			public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ref.set(invocationOnMock.getArguments()[0]);
+				return null;
+			}
+		}).when(keyedStateStore).getPartitionedState(any(StateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
+
 		return operatorMock;
 	}
 
@@ -179,29 +186,32 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
 	private static AbstractStreamOperator createPlainMockOp() throws Exception {
 
 		AbstractStreamOperator operatorMock = mock(AbstractStreamOperator.class);
-		when(operatorMock.getExecutionConfig()).thenReturn(new ExecutionConfig());
-
-		when(operatorMock.getPartitionedState(any(ListStateDescriptor.class))).thenAnswer(
-				new Answer<ListState<String>>() {
-
-					@Override
-					public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ListStateDescriptor descr =
-								(ListStateDescriptor) invocationOnMock.getArguments()[0];
-
-						AbstractKeyedStateBackend backend = new MemoryStateBackend().createKeyedStateBackend(
-								new DummyEnvironment("test_task", 1, 0),
-								new JobID(),
-								"test_op",
-								IntSerializer.INSTANCE,
-								1,
-								new KeyGroupRange(0, 0),
-								new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
-						backend.setCurrentKey(0);
-						return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
-					}
-				});
+		ExecutionConfig config = new ExecutionConfig();
+		DefaultKeyedStateStore keyedStateStore = spy(new DefaultKeyedStateStore(mock(AbstractKeyedStateBackend.class), config));
+
+		when(operatorMock.getExecutionConfig()).thenReturn(config);
 
+		doAnswer(new Answer<ListState<String>>() {
+
+			@Override
+			public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ListStateDescriptor descr =
+						(ListStateDescriptor) invocationOnMock.getArguments()[0];
+
+				AbstractKeyedStateBackend backend = new MemoryStateBackend().createKeyedStateBackend(
+						new DummyEnvironment("test_task", 1, 0),
+						new JobID(),
+						"test_op",
+						IntSerializer.INSTANCE,
+						1,
+						new KeyGroupRange(0, 0),
+						new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+				backend.setCurrentKey(0);
+				return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
+			}
+		}).when(keyedStateStore).getPartitionedState(any(ListStateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
 		return operatorMock;
 	}
 	
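The rewrite above switches from when(mock.getPartitionedState(...)).thenAnswer(...) to doAnswer(...).when(spy).getPartitionedState(...). For plain mocks both forms work, but on a Mockito spy when(spy.method()) invokes the real method during stubbing; the doAnswer form stubs without touching the real implementation. A small standalone sketch of that distinction (assumes Mockito on the classpath):

	import static org.mockito.Mockito.doAnswer;
	import static org.mockito.Mockito.spy;

	import java.util.ArrayList;
	import java.util.List;

	import org.mockito.invocation.InvocationOnMock;
	import org.mockito.stubbing.Answer;

	public class SpyStubbingExample {
		public static void main(String[] args) {
			List<String> real = new ArrayList<>();
			List<String> spied = spy(real);

			// doAnswer(...).when(spied) stubs add("x") WITHOUT running the real add():
			doAnswer(new Answer<Boolean>() {
				@Override
				public Boolean answer(InvocationOnMock invocation) {
					System.out.println("intercepted: " + invocation.getArguments()[0]);
					return true;
				}
			}).when(spied).add("x");

			spied.add("x");                  // prints "intercepted: x"
			System.out.println(real.size()); // 0 - the real list was never touched
		}
	}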
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index 59242e8362f7d..f2fc876ea74b9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -32,6 +32,7 @@
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -973,10 +974,7 @@ public long getNextExpectedCheckpointId() {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index b6d0450e458d8..7cfbb66234879 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -29,6 +29,7 @@
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.Test;
 
 import java.util.Arrays;
@@ -366,10 +367,7 @@ private CheckpointSequenceValidator(long... checkpointIDs) {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 
 			throw new UnsupportedOperationException("should never be called");
 		}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
index d1ba4898bcbd8..e5e26e9e6a597 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
@@ -129,7 +129,7 @@ public void testCommitterException() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(0, 0);
+		testHarness.snapshotLegacy(0, 0);
 		testHarness.notifyOfCompletedCheckpoint(0);
 
 		//isCommitted should have failed, thus sendValues() should never have been called
@@ -140,7 +140,7 @@ public void testCommitterException() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(1, 0);
+		testHarness.snapshotLegacy(1, 0);
 		testHarness.notifyOfCompletedCheckpoint(1);
 
 		//previous CP should be retried, but will fail the CP commit. Second CP should be skipped.
@@ -151,7 +151,7 @@ public void testCommitterException() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(2, 0);
+		testHarness.snapshotLegacy(2, 0);
 		testHarness.notifyOfCompletedCheckpoint(2);
 
 		//all CP's should be retried and succeed; since one CP was written twice we have 2 * 10 + 10 + 10 = 40 values
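The assertions above ("previous CP should be retried", "2 * 10 + 10 + 10 = 40 values") follow from the write-ahead contract: a checkpoint whose commit failed stays pending and is retried on a later notifyOfCompletedCheckpoint, so a batch can reach the external system twice. A schematic of that retry loop, illustrative only and not the actual GenericWriteAheadSink code:

	import java.util.Iterator;
	import java.util.TreeMap;

	class PendingCheckpoints {
		private final TreeMap<Long, byte[]> pending = new TreeMap<>();

		void store(long checkpointId, byte[] batch) {
			pending.put(checkpointId, batch);
		}

		// Retries every pending checkpoint up to the acknowledged one; failures stay queued.
		void notifyCompleted(long checkpointId) {
			Iterator<Long> it = pending.headMap(checkpointId, true).keySet().iterator();
			while (it.hasNext()) {
				long id = it.next();
				if (tryCommit(id, pending.get(id))) {
					it.remove(); // forget the batch only after a successful commit
				}
			}
		}

		private boolean tryCommit(long id, byte[] batch) {
			return true; // stand-in for the external committer
		}
	}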
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
index b7203b5ac0b33..3d1e6e8926dcc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
@@ -66,7 +66,7 @@ public void testIdealCircumstances() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -74,7 +74,7 @@ public void testIdealCircumstances() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -82,7 +82,7 @@ public void testIdealCircumstances() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsIdealCircumstances(testHarness, sink);
@@ -105,7 +105,7 @@ public void testDataPersistenceUponMissedNotify() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -113,14 +113,14 @@ public void testDataPersistenceUponMissedNotify() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 
 		for (int x = 0; x < 20; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataPersistenceUponMissedNotify(testHarness, sink);
@@ -143,7 +143,7 @@ public void testDataDiscardingUponRestore() throws Exception {
 			elementCounter++;
 		}
 
-		StreamStateHandle latestSnapshot = testHarness.snapshot(snapshotCount++, 0);
+		StreamStateHandle latestSnapshot = testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -166,7 +166,7 @@ public void testDataDiscardingUponRestore() throws Exception {
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataDiscardingUponRestore(testHarness, sink);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 51e61a1ed9f9a..e96109e11fd2f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -522,7 +522,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
 
 			// draw a snapshot and dispose the window
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
@@ -611,7 +611,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
 			// draw a snapshot
 			List resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 12a842f415c52..802329b50fd3a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -634,7 +634,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
 			// draw a snapshot
 			List> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 
@@ -727,7 +727,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
 			// draw a snapshot
 			List> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
index ba803e3f793c0..2b0b915da9b3c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
@@ -125,7 +125,7 @@ private void testSlidingEventTimeWindows(OneInputStreamOperatorTestHarness<Tuple2<String, Integer>, Tuple2<String, Integer>> testHarness) throws Exception {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -465,7 +465,7 @@ public void testReduceSessionWindows() throws Exception {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 3), 2500));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -543,7 +543,7 @@ public void testSessionWindowsWithCountTrigger() throws Exception {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -641,7 +641,7 @@ public void testPointSessions() throws Exception {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 33), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -796,7 +796,7 @@ public void testCountTrigger() throws Exception {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 1999));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 
 		testHarness.close();
 
@@ -884,7 +884,7 @@ operator, new ExecutionConfig(), timer,
 		operator.processingTimeTimersQueue.add(timer2);
 		operator.processingTimeTimersQueue.add(timer3);
 		
-		StreamStateHandle snapshot = testHarness.snapshot(0, 0);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0, 0);
 
 		WindowOperator<String, Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple2<String, Integer>, TimeWindow> otherOperator = new WindowOperator<>(
 				SlidingEventTimeWindows.of(Time.of(WINDOW_SIZE, TimeUnit.SECONDS), Time.of(WINDOW_SLIDE, TimeUnit.SECONDS)),
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index b5b6582939a1e..ee5a203551995 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -45,6 +45,7 @@
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
@@ -124,8 +125,17 @@ private static TaskDeploymentDescriptor createTaskDeploymentDescriptor(
 			StreamStateHandle state) throws IOException {
 
 		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
-		List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList();
-		List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromBackend = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromStream = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
+
+		TaskStateHandles taskStateHandles = new TaskStateHandles(
+				operatorState,
+				operatorStateBackend,
+				operatorStateStream,
+				keyGroupStateFromBackend,
+				keyGroupStateFromStream);
 
 		return new TaskDeploymentDescriptor(
 				new JobID(),
@@ -143,9 +153,7 @@ private static TaskDeploymentDescriptor createTaskDeploymentDescriptor(
 				Collections.emptyList(),
 				Collections.emptyList(),
 				0,
-				operatorState,
-				keyGroupState,
-				partitionableOperatorState);
+				taskStateHandles);
 	}
 
 	private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {
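The hunk above shows the shape of the refactoring: instead of threading three loose state collections through the deployment descriptor, everything a task needs for restore is grouped into one TaskStateHandles argument. A simplified holder in the same spirit (field names here are illustrative; the actual class lives in org.apache.flink.runtime.state):

	import java.util.Collection;
	import java.util.List;

	// Groups all restore state behind a single parameter instead of several lists.
	final class TaskStateSnapshot<H> {
		final List<H> legacyOperatorState;            // non-partitionable, per chained operator
		final List<Collection<H>> managedOperatorState;
		final List<Collection<H>> rawOperatorState;
		final Collection<H> managedKeyedState;
		final Collection<H> rawKeyedState;

		TaskStateSnapshot(
				List<H> legacyOperatorState,
				List<Collection<H>> managedOperatorState,
				List<Collection<H>> rawOperatorState,
				Collection<H> managedKeyedState,
				Collection<H> rawKeyedState) {
			this.legacyOperatorState = legacyOperatorState;
			this.managedOperatorState = managedOperatorState;
			this.rawOperatorState = rawOperatorState;
			this.managedKeyedState = managedKeyedState;
			this.rawKeyedState = rawKeyedState;
		}
	}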
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 1b2b723505e96..3dd2ed7ba993b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -31,15 +31,13 @@
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -64,8 +62,6 @@
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -74,7 +70,6 @@
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
@@ -390,7 +385,7 @@ public void testSnapshottingAndRestoring() throws Exception {
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
+		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<>();
-		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState());
+		restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles()));
 
+		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
@@ -482,9 +477,7 @@ public IN getKey(IN value) throws Exception {
 
 	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		private volatile long checkpointId;
-		private volatile ChainedStateHandle<StreamStateHandle> state;
-		private volatile List<KeyGroupsStateHandle> keyGroupStates;
-		private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+		private volatile SubtaskState checkpointStateHandles;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -492,24 +485,6 @@ public long getCheckpointId() {
 			return checkpointId;
 		}
 
-		public ChainedStateHandle<StreamStateHandle> getState() {
-			return state;
-		}
-
-		List<KeyGroupsStateHandle> getKeyGroupStates() {
-			List<KeyGroupsStateHandle> result = new ArrayList<>();
-			for (KeyGroupsStateHandle keyGroupState : keyGroupStates) {
-				if (keyGroupState != null) {
-					result.add(keyGroupState);
-				}
-			}
-			return result;
-		}
-
-		List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
-			return partitionableOperatorState;
-		}
-
 		AcknowledgeStreamMockEnvironment(
 				Configuration jobConfig, Configuration taskConfig,
 				ExecutionConfig executionConfig, long memorySize,
@@ -521,26 +496,20 @@ List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
 		@Override
 		public void acknowledgeCheckpoint(
 				CheckpointMetaData checkpointMetaData,
-				CheckpointStateHandles checkpointStateHandles) {
+				SubtaskState checkpointStateHandles) {
 
 			this.checkpointId = checkpointMetaData.getCheckpointId();
-			if(checkpointStateHandles != null) {
-				this.state = checkpointStateHandles.getNonPartitionedStateHandles();
-				this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle();
-				ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles();
-				Collection[] ia = new Collection[chainedStateHandle.getLength()];
-				this.partitionableOperatorState = Arrays.asList(ia);
-
-				for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
-					partitionableOperatorState.set(i, Collections.singletonList(chainedStateHandle.get(i)));
-				}
-			}
+			this.checkpointStateHandles = checkpointStateHandles;
 			checkpointLatch.trigger();
 		}
 
 		public OneShotLatch getCheckpointLatch() {
 			return checkpointLatch;
 		}
+
+		public SubtaskState getCheckpointStateHandles() {
+			return checkpointStateHandles;
+		}
 	}
 
 	private static class TestingStreamOperator<IN, OUT>
@@ -580,9 +549,7 @@ public void open() throws Exception {
 		}
 
 		@Override
-		public RunnableFuture<OperatorStateHandle> snapshotState(
-				long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
-
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			ListState<Integer> partitionableState =
 					getOperatorStateBackend().getOperatorState(TEST_DESCRIPTOR);
 			partitionableState.clear();
@@ -591,7 +558,11 @@ public RunnableFuture snapshotState(
 			partitionableState.add(4711);
 
 			++numberSnapshotCalls;
-			return super.snapshotState(checkpointId, timestamp, streamFactory);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+
 		}
 
 		TestingStreamOperator(long seed, long recoveryTimestamp) {
@@ -643,7 +614,6 @@ public void restoreState(FSDataInputStream in) throws Exception {
 			assertEquals(random.nextInt(), (int) operatorState);
 		}
 
-
 		private Serializable generateFunctionState() {
 			return random.nextInt();
 		}
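AcknowledgeStreamMockEnvironment above captures the SubtaskState from acknowledgeCheckpoint into a volatile field and trips a latch, letting the test thread block until the checkpoint has round-tripped through the task. The same capture-and-signal pattern in plain Java, with a CountDownLatch standing in for Flink's OneShotLatch:

	import java.util.concurrent.CountDownLatch;

	class AckCapture<S> {
		private volatile S captured;
		private final CountDownLatch latch = new CountDownLatch(1);

		// Called from the task thread when the checkpoint is acknowledged.
		void acknowledge(S state) {
			captured = state;
			latch.countDown();
		}

		// Called from the test thread; blocks until acknowledge() ran.
		S await() throws InterruptedException {
			latch.await();
			return captured;
		}
	}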
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index f852682288a30..36ecf597f3b55 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -28,6 +28,7 @@
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -50,7 +51,6 @@
 import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -123,7 +123,7 @@ public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig,
 
 	public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize,
 								 MockInputSplitProvider inputSplitProvider, int bufferSize) {
-		this(jobConfig, taskConfig, null, memorySize, inputSplitProvider, bufferSize);
+		this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize);
 	}
 
 	public void addInputGate(InputGate gate) {
@@ -313,7 +313,7 @@ public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData) {
 
 	@Override
 	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+			CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 8aae19fa4e247..94f6d5ae9c6c4 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -200,13 +200,13 @@ public void notifyTaskExecutionStateChanged(TaskExecutionState taskExecutionStat
 		}
 	}
 
-	private Task createTask(
+	public static Task createTask(
 			Class invokable,
 			StreamConfig taskConfig,
 			Configuration taskManagerConfig) throws Exception {
 
 		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
-		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
+		when(libCache.getClassLoader(any(JobID.class))).thenReturn(StreamTaskTest.class.getClassLoader());
 		
 		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
 		ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index b9211b10b5e10..1bb3fb0e0c693 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -214,6 +214,17 @@ public void testCheckpointBarriers() throws Exception {
 		expectedOutput.add(new StreamRecord("111", initialTime));
 
 		testHarness.waitForInputProcessing();
+
+		// Wait to allow input to end up in the output.
+		// TODO Use count-down latches instead, as a cleaner solution
+		for (int i = 0; i < 20; ++i) {
+			if (testHarness.getOutput().size() >= expectedOutput.size()) {
+				break;
+			} else {
+				Thread.sleep(100);
+			}
+		}
+
 		// we should not yet see the barrier, only the two elements from non-blocked input
 		TestHarnessUtil.assertOutputEquals("Output was not correct.",
 				testHarness.getOutput(),
@@ -224,17 +235,17 @@ public void testCheckpointBarriers() throws Exception {
 		testHarness.processEvent(new CheckpointBarrier(0, 0), 1, 1);
 
 		testHarness.waitForInputProcessing();
+		testHarness.endInput();
+		testHarness.waitForTaskCompletion();
 
 		// now we should see the barrier and after that the buffered elements
 		expectedOutput.add(new CheckpointBarrier(0, 0));
 		expectedOutput.add(new StreamRecord("Hello-0-0", initialTime));
-		TestHarnessUtil.assertOutputEquals("Output was not correct.",
-				testHarness.getOutput(),
-				expectedOutput);
 
-		testHarness.endInput();
+		TestHarnessUtil.assertOutputEquals("Output was not correct.",
+				expectedOutput,
+				testHarness.getOutput());
 
-		testHarness.waitForTaskCompletion();
 
 		List resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput());
 		Assert.assertEquals(4, resultElements.size());
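
The polling loop introduced above is flagged as a stopgap by its own TODO. A latch-based variant would look roughly like the following sketch; it is hypothetical in that it assumes the harness exposes a latch that its output collector counts down once per emitted record (needs java.util.concurrent.CountDownLatch and TimeUnit):

// Hypothetical latch-based wait, per the TODO in the hunk above.
CountDownLatch outputLatch = new CountDownLatch(expectedOutput.size());
// ... the harness's collector would call outputLatch.countDown() per record ...
if (!outputLatch.await(2, TimeUnit.SECONDS)) {
	Assert.fail("Timed out waiting for the expected number of output records.");
}
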
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 5275a39c50f9f..41968e6193372 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -31,14 +31,16 @@
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
@@ -60,7 +62,7 @@ public class KeyedOneInputStreamOperatorTestHarness
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
-	private KeyGroupsStateHandle restoredKeyedState = null;
+	private Collection restoredKeyedState = null;
 
 	public KeyedOneInputStreamOperatorTestHarness(
 			OneInputStreamOperator operator,
@@ -138,7 +140,7 @@ public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwa
 								keySerializer,
 								numberOfKeyGroups,
 								keyGroupRange,
-								Collections.singletonList(restoredKeyedState),
+								restoredKeyedState,
 								mockTask.getEnvironment().getTaskKvStateRegistry());
 						restoredKeyedState = null;
 						return keyedStateBackend;
@@ -154,7 +156,7 @@ public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwa
 	 *
 	 */
 	@Override
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
 		// simply use an in-memory handle
 		MemoryStateBackend backend = new MemoryStateBackend();
 
@@ -185,7 +187,7 @@ public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exce
 	}
 
 	/**
-	 * 
+	 *
 	 */
 	@Override
 	public void restore(StreamStateHandle snapshot) throws Exception {
@@ -198,7 +200,7 @@ public void restore(StreamStateHandle snapshot) throws Exception {
 			byte keyedStatePresent = (byte) inStream.read();
 			if (keyedStatePresent == 1) {
 				ObjectInputStream ois = new ObjectInputStream(inStream);
-				this.restoredKeyedState = (KeyGroupsStateHandle) ois.readObject();
+				this.restoredKeyedState = Collections.singletonList((KeyGroupsStateHandle) ois.readObject());
 			}
 		}
 	}
@@ -208,8 +210,16 @@ public void restore(StreamStateHandle snapshot) throws Exception {
 	 */
 	public void close() throws Exception {
 		super.close();
-		if(keyedStateBackend != null) {
+		if (keyedStateBackend != null) {
 			keyedStateBackend.dispose();
 		}
 	}
+
+	@Override
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		if (null != operatorStateHandles) {
+			this.restoredKeyedState = operatorStateHandles.getManagedKeyedState();
+		}
+		super.initializeState(operatorStateHandles);
+	}
 }
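
With snapshot renamed to snapshotLegacy and initializeState overridden to capture the managed keyed state, the keyed harness still supports the old round trip. A hedged usage sketch under the patched API (operator construction elided):

// Sketch only: legacy snapshot/restore round trip through the patched harness.
KeyedOneInputStreamOperatorTestHarness<String, String, String> harness = /* construct as before */ null;
harness.open();
// ... process elements ...
StreamStateHandle handle = harness.snapshotLegacy(0L, 0L);
harness.close();

KeyedOneInputStreamOperatorTestHarness<String, String, String> restored = /* construct as before */ null;
restored.restore(handle);
restored.open(); // the keyed backend is rebuilt from restoredKeyedState on first access
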
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index d1622ffaf3ca2..9f8d223b229d1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -28,22 +28,26 @@
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.DefaultTimeServiceProvider;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -65,7 +69,7 @@
  */
 public class OneInputStreamOperatorTestHarness {
 
-	protected static final int MAX_PARALLELISM = 10;
+	public static final int MAX_PARALLELISM = 10;
 
 	final OneInputStreamOperator operator;
 
@@ -81,6 +85,8 @@ public class OneInputStreamOperatorTestHarness {
 
 	StreamTask mockTask;
 
+	ClosableRegistry closableRegistry;
+
 	// use this as default for tests
 	AbstractStateBackend stateBackend = new MemoryStateBackend();
 
@@ -88,6 +94,7 @@ public class OneInputStreamOperatorTestHarness {
 	 * Whether setup() was called on the operator. This is reset when calling close().
 	 */
 	private boolean setupCalled = false;
+	private boolean initializeCalled = false;
 
 	private volatile boolean wasFailedExternally = false;
 
@@ -121,6 +128,7 @@ public OneInputStreamOperatorTestHarness(
 		this.config.setCheckpointingEnabled(true);
 		this.executionConfig = executionConfig;
 		this.checkpointLock = checkpointLock;
+		this.closableRegistry = new ClosableRegistry();
 
 		final Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024, underlyingConfig, executionConfig, MAX_PARALLELISM, 1, 0);
 		mockTask = mock(StreamTask.class);
@@ -132,6 +140,7 @@ public OneInputStreamOperatorTestHarness(
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
 		when(mockTask.getUserCodeClassLoader()).thenReturn(this.getClass().getClassLoader());
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		doAnswer(new Answer() {
 			@Override
@@ -154,6 +163,26 @@ public CheckpointStreamFactory answer(InvocationOnMock invocationOnMock) throws
 			throw new RuntimeException(e.getMessage(), e);
 		}
 
+		try {
+			doAnswer(new Answer() {
+				@Override
+				public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
+					final StreamOperator operator = (StreamOperator) invocationOnMock.getArguments()[0];
+					final Collection stateHandles = (Collection) invocationOnMock.getArguments()[1];
+					OperatorStateBackend osb;
+					if (null == stateHandles) {
+						osb = stateBackend.createOperatorStateBackend(env, operator.getClass().getSimpleName());
+					} else {
+						osb = stateBackend.restoreOperatorStateBackend(env, operator.getClass().getSimpleName(), stateHandles);
+					}
+					mockTask.getCancelables().registerClosable(osb);
+					return osb;
+				}
+			}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));
+		} catch (Exception e) {
+			throw new RuntimeException(e.getMessage(), e);
+		}
+
 		timeServiceProvider = testTimeProvider != null ? testTimeProvider :
 			new DefaultTimeServiceProvider(mockTask, this.checkpointLock);
 
@@ -199,8 +228,7 @@ public ConcurrentLinkedQueue getOutput() {
 	}
 
 	/**
-	 * Calls
-	 * {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} ()}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 */
 	public void setup() throws Exception {
 		operator.setup(mockTask, config, new MockOutput());
@@ -208,21 +236,48 @@ public void setup() throws Exception {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}. This also
-	 * calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 * if it was not called before.
 	 */
-	public void open() throws Exception {
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
 		if (!setupCalled) {
 			setup();
 		}
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
+
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)} if it
+	 * was not called before.
+	 */
+	public void open() throws Exception {
+		if (!initializeCalled) {
+			initializeState(null);
+		}
 		operator.open();
 	}
 
 	/**
 	 *
 	 */
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public OperatorSnapshotResult snapshot(long checkpointId, long timestamp) throws Exception {
+
+		CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
+				new JobID(),
+				"test_op");
+
+		return operator.snapshotState(checkpointId, timestamp, streamFactory);
+	}
+
+	/**
+	 * Takes a snapshot through the legacy {@link StreamCheckpointedOperator}
+	 * path, writing the operator's state into a raw checkpoint output stream.
+	 */
+	@Deprecated
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
+
 		CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory(
 				new JobID(),
 				"test_op").createCheckpointStateOutputStream(checkpointId, timestamp);
@@ -244,6 +299,7 @@ public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 	/**
 	 *
 	 */
+	@Deprecated
 	public void restore(StreamStateHandle snapshot) throws Exception {
 		if(operator instanceof StreamCheckpointedOperator) {
 			try (FSDataInputStream in = snapshot.openInputStream()) {
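
The harness now mirrors the task lifecycle: setup(), then initializeState(), then open(), with open() backfilling initializeState(null) for fresh starts and snapshot() returning an OperatorSnapshotResult instead of a raw handle. A hedged usage sketch (operator construction elided):

// Sketch of the lifecycle the patched harness enforces.
OneInputStreamOperatorTestHarness<String, String> harness = /* construct as before */ null;
harness.setup();               // optional: initializeState() calls it if missing
harness.initializeState(null); // optional: open() passes null if this was skipped
harness.open();
// ... process elements ...
OperatorSnapshotResult result = harness.snapshot(0L, 0L);
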
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
index 32b4c77b7fc78..7df68483a5694 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
@@ -25,12 +25,14 @@
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -56,6 +58,10 @@ public class TwoInputStreamOperatorTestHarness {
 
 	final Object checkpointLock;
 
+	final ClosableRegistry closableRegistry;
+
+	boolean initializeCalled = false;
+
 	public TwoInputStreamOperatorTestHarness(TwoInputStreamOperator operator) {
 		this(operator, new StreamConfig(new Configuration()));
 	}
@@ -65,6 +71,7 @@ public TwoInputStreamOperatorTestHarness(TwoInputStreamOperator o
 		this.outputList = new ConcurrentLinkedQueue();
 		this.executionConfig = new ExecutionConfig();
 		this.checkpointLock = new Object();
+		this.closableRegistry = new ClosableRegistry();
 
 		Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
 		StreamTask mockTask = mock(StreamTask.class);
@@ -73,6 +80,7 @@ public TwoInputStreamOperatorTestHarness(TwoInputStreamOperator o
 		when(mockTask.getConfiguration()).thenReturn(config);
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		operator.setup(mockTask, new StreamConfig(new Configuration()), new MockOutput());
 	}
@@ -86,11 +94,22 @@ public ConcurrentLinkedQueue getOutput() {
 		return outputList;
 	}
 
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 */
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
 
 	/**
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
 	 */
 	public void open() throws Exception {
+		if (!initializeCalled) {
+			initializeState(mock(OperatorStateHandles.class));
+		}
+
 		operator.open();
 	}
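
Note that, unlike the keyed harness, open() here backfills initializeState() with a Mockito mock of OperatorStateHandles rather than null. A short sketch of the resulting test-side contract (operator construction elided):

// Sketch only: open() triggers initializeState() if the test never called it.
TwoInputStreamOperatorTestHarness<String, Integer, String> harness =
		new TwoInputStreamOperatorTestHarness<>(operator);
harness.open(); // runs initializeState(mock(OperatorStateHandles.class)) first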
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
index ab8b70f450400..a4e26f006635c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
@@ -165,7 +165,7 @@ public void compareActualToExpectedOutput(String errorMessage) {
 	 * Takes a snapshot of the current state of the operator. This can be used to test fault-tolerance.
 	 */
 	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
-		return testHarness.snapshot(checkpointId, timestamp);
+		return testHarness.snapshotLegacy(checkpointId, timestamp);
 	}
 
 	/**
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 48d720a0e0c0c..a0a971a8448ad 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -22,6 +22,7 @@
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.java.functions.KeySelector;
@@ -33,17 +34,21 @@
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 import org.junit.AfterClass;
@@ -72,7 +77,6 @@
 
 /**
  * TODO : parameterize to test all different state backends!
- * TODO: reactivate ignored test as soon as savepoints work with deactivated checkpoints.
  */
 public class RescalingITCase extends TestLogger {
 
@@ -80,6 +84,10 @@ public class RescalingITCase extends TestLogger {
 	private static final int slotsPerTaskManager = 2;
 	private static final int numSlots = numTaskManagers * slotsPerTaskManager;
 
+	enum OperatorCheckpointMethod {
+		NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED
+	}
+
 	private static TestingCluster cluster;
 
 	@ClassRule
@@ -243,7 +251,7 @@ public void testSavepointRescalingNonPartitionedStateCausesException() throws Ex
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, false);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			jobID = jobGraph.getJobID();
 
@@ -281,7 +289,7 @@ public void testSavepointRescalingNonPartitionedStateCausesException() throws Ex
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, false);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -440,12 +448,22 @@ public void testSavepointRescalingWithKeyedAndNonPartitionedState() throws Excep
 
 	@Test
 	public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(false);
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
 	}
 
 	@Test
 	public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(true);
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
+	}
+
+	@Test
+	public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
+	}
+
+	@Test
+	public void testSavepointRescalingOutPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.LIST_CHECKPOINTED);
 	}
 
 
@@ -453,7 +471,7 @@ public void testSavepointRescalingOutPartitionedOperatorState() throws Exception
 	 * Tests rescaling of partitioned operator state. More specifically, we test the mechanism with {@link ListCheckpointed}
 	 * as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
 	 */
-	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) throws Exception {
+	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
 		final int parallelism = scaleOut ? numSlots : numSlots / 2;
 		final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
 		final int maxParallelism = 13;
@@ -466,13 +484,18 @@ public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) thr
 
 		int counterSize = Math.max(parallelism, parallelism2);
 
-		PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
-		PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+			PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		} else {
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE = new int[counterSize];
+		}
 
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, true);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
 
 			jobID = jobGraph.getJobID();
 
@@ -511,7 +534,7 @@ public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) thr
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, true);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, checkpointMethod);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -522,12 +545,22 @@ public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) thr
 			int sumExp = 0;
 			int sumAct = 0;
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
-				sumExp += c;
-			}
+			if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+				for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
+
+				for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
+			} else {
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
-				sumAct += c;
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
 			}
 
 			assertEquals(sumExp, sumAct);
@@ -550,7 +583,7 @@ public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) thr
 	//------------------------------------------------------------------------------------------------------------------
 
 	private static JobGraph createJobGraphWithOperatorState(
-			int parallelism, int maxParallelism, boolean partitionedOperatorState) {
+			int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
 
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
@@ -560,8 +593,23 @@ private static JobGraph createJobGraphWithOperatorState(
 
 		StateSourceBase.workStartedLatch = new CountDownLatch(1);
 
-		DataStream input = env.addSource(
-				partitionedOperatorState ? new PartitionedStateSource() : new NonPartitionedStateSource());
+		SourceFunction src;
+
+		switch (checkpointMethod) {
+			case CHECKPOINTED_FUNCTION:
+				src = new PartitionedStateSource();
+				break;
+			case LIST_CHECKPOINTED:
+				src = new PartitionedStateSourceListCheckpointed();
+				break;
+			case NON_PARTITIONED:
+				src = new NonPartitionedStateSource();
+				break;
+			default:
+				throw new IllegalArgumentException("Unknown checkpoint method: " + checkpointMethod);
+		}
+
+		DataStream input = env.addSource(src);
 
 		input.addSink(new DiscardingSink());
 
@@ -718,7 +766,7 @@ public void restoreState(Integer state) throws Exception {
 		}
 	}
 
-	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction> {
+	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction> implements CheckpointedFunction {
 
 		private static final long serialVersionUID = 5273172591283191348L;
 
@@ -733,12 +781,6 @@ private static class SubtaskIndexFlatMapper extends RichFlatMapFunction("counter", Integer.class, 0));
-			sum = getRuntimeContext().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
-		}
-
 		@Override
 		public void flatMap(Integer value, Collector> out) throws Exception {
 
@@ -753,6 +795,17 @@ public void flatMap(Integer value, Collector> out) thro
 				workCompletedLatch.countDown();
 			}
 		}
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			// all managed, nothing to do.
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			counter = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+			sum = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+		}
 	}
 
 	private static class CollectionSink implements SinkFunction {
@@ -824,9 +877,9 @@ public void restoreState(Integer state) throws Exception {
 		}
 	}
 
-	private static class PartitionedStateSource extends StateSourceBase implements ListCheckpointed {
+	private static class PartitionedStateSourceListCheckpointed extends StateSourceBase implements ListCheckpointed {
 
-		private static final long serialVersionUID = -359715965103593462L;
+		private static final long serialVersionUID = -4357864582992546L;
 		private static final int NUM_PARTITIONS = 7;
 
 		private static int[] CHECK_CORRECT_SNAPSHOT;
@@ -860,4 +913,46 @@ public void restoreState(List state) throws Exception {
 			CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
 		}
 	}
+
+	private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction {
+
+		private static final long serialVersionUID = -359715965103593462L;
+		private static final int NUM_PARTITIONS = 7;
+
+		private ListState counterPartitions;
+
+		private static int[] CHECK_CORRECT_SNAPSHOT;
+		private static int[] CHECK_CORRECT_RESTORE;
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+
+			CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+
+			int div = counter / NUM_PARTITIONS;
+			int mod = counter % NUM_PARTITIONS;
+
+			for (int i = 0; i < NUM_PARTITIONS; ++i) {
+				int partitionValue = div;
+				if (mod > 0) {
+					--mod;
+					++partitionValue;
+				}
+				counterPartitions.add(partitionValue);
+			}
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			this.counterPartitions =
+					context.getManagedOperatorStateStore().getSerializableListState("counter_partitions");
+			if (context.isRestored()) {
+				for (int v : counterPartitions.get()) {
+					counter += v;
+				}
+				CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+			}
+		}
+	}
 }
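
For reference, PartitionedStateSource.snapshotState above spreads counter over NUM_PARTITIONS list entries so that a restore at a different parallelism can simply resum them. A worked check of the split arithmetic, extracted into a standalone helper for illustration:

// Illustration of the split in snapshotState: counter = 10, partitions = 7
// gives div = 1, mod = 3, hence [2, 2, 2, 1, 1, 1, 1], which resums to 10.
static int[] split(int counter, int partitions) {
	int div = counter / partitions;
	int mod = counter % partitions;
	int[] out = new int[partitions];
	for (int i = 0; i < partitions; ++i) {
		out[i] = div + (i < mod ? 1 : 0);
	}
	return out;
}
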
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index 74de942362c18..5c4986b6577ba 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -338,7 +338,8 @@ protected void run() {
 
 					assertNotNull(subtaskState);
 					errMsg = "Initial operator state mismatch.";
-					assertEquals(errMsg, subtaskState.getChainedStateHandle(), tdd.getOperatorState());
+					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
+							tdd.getTaskStateHandles().getLegacyOperatorState());
 				}
 			}
 
@@ -364,7 +365,7 @@ protected void run() {
 
 			for (TaskState stateForTaskGroup : savepoint.getTaskStates()) {
 				for (SubtaskState subtaskState : stateForTaskGroup.getStates()) {
-					ChainedStateHandle streamTaskState = subtaskState.getChainedStateHandle();
+					ChainedStateHandle streamTaskState = subtaskState.getLegacyOperatorState();
 
 					for (int i = 0; i < streamTaskState.getLength(); i++) {
 						if (streamTaskState.get(i) != null) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 2a635ab5db9fd..963d18a7bcbcd 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -38,7 +38,7 @@
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.List;
+import java.util.Collection;
 
 import static org.junit.Assert.fail;
 
@@ -119,7 +119,7 @@ public  AbstractKeyedStateBackend restoreKeyedStateBackend(
 				TypeSerializer keySerializer,
 				int numberOfKeyGroups,
 				KeyGroupRange keyGroupRange,
-				List restoredState,
+				Collection restoredState,
 				TaskKvStateRegistry kvStateRegistry) throws Exception {
 			throw new SuccessException();
 		}
diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
index 84d2fe675331f..7ce040b640100 100644
--- a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
+++ b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
@@ -31,12 +31,12 @@
 import org.apache.flink.configuration.GlobalConfiguration;
 import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.runtime.clusterframework.ApplicationStatus;
+import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
 import org.apache.flink.runtime.security.SecurityContext;
 import org.apache.flink.yarn.AbstractYarnClusterDescriptor;
-import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.flink.yarn.YarnClusterClient;
-import org.apache.flink.runtime.clusterframework.ApplicationStatus;
-import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
+import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.slf4j.Logger;
@@ -61,6 +61,7 @@
 import java.util.Properties;
 
 import static org.apache.flink.client.cli.CliFrontendParser.ADDRESS_OPTION;
+import static org.apache.flink.configuration.ConfigConstants.HA_ZOOKEEPER_NAMESPACE_KEY;
 
 /**
  * Class handling the command line interface to the YARN session.