From d6b8b0112c6a4cf3f2cbf5eb758599e15d796aab Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Wed, 21 Sep 2016 14:55:58 +0200 Subject: [PATCH 1/2] [FLINK-4603] KeyedStateBackend can restore user code classes --- .../state/RocksDBKeyedStateBackend.java | 19 +++++++----- .../streaming/state/RocksDBStateBackend.java | 2 ++ .../apache/flink/util/InstantiationUtil.java | 6 ++-- .../runtime/state/KeyedStateBackend.java | 4 +++ .../state/filesystem/FsStateBackend.java | 2 ++ .../state/heap/HeapKeyedStateBackend.java | 29 ++++++++++--------- .../state/memory/MemoryStateBackend.java | 5 +++- .../streaming/runtime/tasks/StreamTask.java | 8 ++++- 8 files changed, 50 insertions(+), 25 deletions(-) diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index 177c09f061e98..d5a96af8c0c62 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -47,6 +47,7 @@ import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.SerializableObject; +import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.rocksdb.ColumnFamilyDescriptor; import org.rocksdb.ColumnFamilyHandle; @@ -63,8 +64,6 @@ import javax.annotation.concurrent.GuardedBy; import java.io.File; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -125,6 +124,7 @@ public class RocksDBKeyedStateBackend extends KeyedStateBackend { public RocksDBKeyedStateBackend( JobID jobId, String operatorIdentifier, + ClassLoader userCodeClassLoader, File instanceBasePath, DBOptions dbOptions, ColumnFamilyOptions columnFamilyOptions, @@ -134,7 +134,7 @@ public RocksDBKeyedStateBackend( KeyGroupRange keyGroupRange ) throws Exception { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); this.operatorIdentifier = operatorIdentifier; this.jobId = jobId; @@ -177,6 +177,7 @@ public RocksDBKeyedStateBackend( public RocksDBKeyedStateBackend( JobID jobId, String operatorIdentifier, + ClassLoader userCodeClassLoader, File instanceBasePath, DBOptions dbOptions, ColumnFamilyOptions columnFamilyOptions, @@ -189,6 +190,7 @@ public RocksDBKeyedStateBackend( this( jobId, operatorIdentifier, + userCodeClassLoader, instanceBasePath, dbOptions, columnFamilyOptions, @@ -455,8 +457,8 @@ private void writeKVStateMetaData() throws IOException, InterruptedException { checkInterrupted(); //write StateDescriptor for this k/v state - ObjectOutputStream ooOut = new ObjectOutputStream(outStream); - ooOut.writeObject(column.getValue().f1); + InstantiationUtil.serializeObject(outStream, column.getValue().f1); + //retrieve iterator for this k/v states ReadOptions readOptions = new ReadOptions(); readOptions.setSnapshot(snapshot); @@ -649,8 +651,11 @@ private void restoreKVStateMetaData() throws IOException, ClassNotFoundException //restore the empty columns for the k/v states through the metadata for (int i = 0; i < numColumns; i++) { - ObjectInputStream ooIn = new ObjectInputStream(currentStateHandleInStream); - StateDescriptor stateDescriptor = (StateDescriptor) ooIn.readObject(); + + StateDescriptor stateDescriptor = InstantiationUtil.deserializeObject( + currentStateHandleInStream, + rocksDBKeyedStateBackend.userCodeClassLoader); + Tuple2 columnFamily = rocksDBKeyedStateBackend. kvStateInformation.get(stateDescriptor.getName()); diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java index 0fdbd5f409fc9..b6ce2245a2833 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java @@ -240,6 +240,7 @@ public KeyedStateBackend createKeyedStateBackend( return new RocksDBKeyedStateBackend<>( jobID, operatorIdentifier, + env.getUserClassLoader(), instanceBasePath, getDbOptions(), getColumnOptions(), @@ -264,6 +265,7 @@ public KeyedStateBackend restoreKeyedStateBackend(Environment env, JobID return new RocksDBKeyedStateBackend<>( jobID, operatorIdentifier, + env.getUserClassLoader(), instanceBasePath, getDbOptions(), getColumnOptions(), diff --git a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java index b1dddae68b6ba..de4cffbc7ee8f 100644 --- a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java +++ b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java @@ -299,7 +299,10 @@ public static T deserializeObject(byte[] bytes, ClassLoader cl) throws IOExc @SuppressWarnings("unchecked") public static T deserializeObject(InputStream in, ClassLoader cl) throws IOException, ClassNotFoundException { final ClassLoader old = Thread.currentThread().getContextClassLoader(); - try (ObjectInputStream oois = new ClassLoaderObjectInputStream(in, cl)) { + ObjectInputStream oois; + // not using resource try to avoid AutoClosable's close() on the given stream + try { + oois = new ClassLoaderObjectInputStream(in, cl); Thread.currentThread().setContextClassLoader(cl); return (T) oois.readObject(); } @@ -332,7 +335,6 @@ public static byte[] serializeObject(Object o) throws IOException { public static void serializeObject(OutputStream out, Object o) throws IOException { ObjectOutputStream oos = new ObjectOutputStream(out); oos.writeObject(o); - oos.flush(); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index bf9018e911c63..8db63ee7b672e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -77,14 +77,18 @@ public abstract class KeyedStateBackend { /** KvStateRegistry helper for this task */ protected final TaskKvStateRegistry kvStateRegistry; + protected final ClassLoader userCodeClassLoader; + public KeyedStateBackend( TaskKvStateRegistry kvStateRegistry, TypeSerializer keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange) { this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry); this.keySerializer = Preconditions.checkNotNull(keySerializer); + this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader); this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups); this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index 6d92a4dbbf982..99e368441cb3d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -186,6 +186,7 @@ public KeyedStateBackend createKeyedStateBackend( return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange); } @@ -203,6 +204,7 @@ public KeyedStateBackend restoreKeyedStateBackend( return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange, restoredState); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index 8d13941449219..a655ae69f0430 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -39,12 +39,11 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -75,20 +74,23 @@ public class HeapKeyedStateBackend extends KeyedStateBackend { public HeapKeyedStateBackend( TaskKvStateRegistry kvStateRegistry, TypeSerializer keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange) { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); LOG.info("Initializing heap keyed state backend with stream factory."); } - public HeapKeyedStateBackend(TaskKvStateRegistry kvStateRegistry, + public HeapKeyedStateBackend( + TaskKvStateRegistry kvStateRegistry, TypeSerializer keySerializer, + ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange, List restoredState) throws Exception { - super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange); + super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange); LOG.info("Initializing heap keyed state backend from snapshot."); @@ -189,11 +191,9 @@ public RunnableFuture snapshot( TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer(); TypeSerializer stateSerializer = kvState.getValue().getStateSerializer(); - - ObjectOutputStream oos = new ObjectOutputStream(outView); - oos.writeObject(namespaceSerializer); - oos.writeObject(stateSerializer); - oos.flush(); + + InstantiationUtil.serializeObject(stream, namespaceSerializer); + InstantiationUtil.serializeObject(stream, stateSerializer); kVStateToId.put(kvState.getKey(), kVStateToId.size()); } @@ -266,10 +266,11 @@ public void restorePartitionedState(List state) throws Exc for (int i = 0; i < numKvStates; ++i) { String stateName = inView.readUTF(); - ObjectInputStream ois = new ObjectInputStream(inView); + TypeSerializer namespaceSerializer = + InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader); + TypeSerializer stateSerializer = + InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader); - TypeSerializer namespaceSerializer = (TypeSerializer) ois.readObject(); - TypeSerializer stateSerializer = (TypeSerializer) ois.readObject(); StateTable stateTable = new StateTable(stateSerializer, namespaceSerializer, keyGroupRange); @@ -277,7 +278,7 @@ public void restorePartitionedState(List state) throws Exc kvStatesById.put(i, stateName); } - for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) { + for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) { long offset = keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex); fsDataInputStream.seek(offset); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java index 179dfe76856ad..cc145ff0b3d32 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java @@ -78,7 +78,8 @@ public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorI @Override public KeyedStateBackend createKeyedStateBackend( Environment env, JobID jobID, - String operatorIdentifier, TypeSerializer keySerializer, + String operatorIdentifier, + TypeSerializer keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, TaskKvStateRegistry kvStateRegistry) throws IOException { @@ -86,6 +87,7 @@ public KeyedStateBackend createKeyedStateBackend( return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange); } @@ -103,6 +105,7 @@ public KeyedStateBackend restoreKeyedStateBackend( return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, + env.getUserClassLoader(), numberOfKeyGroups, keyGroupRange, restoredState); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 9c2650911e852..d4638a4871669 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; +import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.metrics.Gauge; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; @@ -585,7 +586,12 @@ private void restoreState() throws Exception { if (operator != null) { LOG.debug("Restore state of task {} in chain ({}).", i, getName()); - operator.restoreState(state.openInputStream()); + FSDataInputStream inputStream = state.openInputStream(); + try { + operator.restoreState(inputStream); + } finally { + inputStream.close(); + } } } } From 78b2a4f048bd62e55471a384169304ca46bbbf60 Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Wed, 21 Sep 2016 17:56:08 +0200 Subject: [PATCH 2/2] [FLINK-4603] Test case --- .../state/heap/HeapKeyedStateBackend.java | 6 +- flink-tests/pom.xml | 19 ++ ...checkpointing-custom_kv_state-assembly.xml | 38 +++ .../test/classloading/ClassLoaderITCase.java | 25 +- .../CheckpointingCustomKvStateProgram.java | 233 ++++++++++++++++++ 5 files changed, 315 insertions(+), 6 deletions(-) create mode 100644 flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml create mode 100644 flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index a655ae69f0430..c13be70fc42d2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -137,7 +137,6 @@ public ReducingState createReducingState(TypeSerializer namespaceSe @SuppressWarnings("unchecked,rawtypes") StateTable stateTable = (StateTable) stateTables.get(stateDesc.getName()); - if (stateTable == null) { stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange); stateTables.put(stateDesc.getName(), stateTable); @@ -191,7 +190,7 @@ public RunnableFuture snapshot( TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer(); TypeSerializer stateSerializer = kvState.getValue().getStateSerializer(); - + InstantiationUtil.serializeObject(stream, namespaceSerializer); InstantiationUtil.serializeObject(stream, stateSerializer); @@ -271,7 +270,8 @@ public void restorePartitionedState(List state) throws Exc TypeSerializer stateSerializer = InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader); - StateTable stateTable = new StateTable(stateSerializer, + StateTable stateTable = new StateTable( + stateSerializer, namespaceSerializer, keyGroupRange); stateTables.put(stateName, stateTable); diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml index b09db1f230b8b..efc95ab6888ba 100644 --- a/flink-tests/pom.xml +++ b/flink-tests/pom.xml @@ -485,6 +485,25 @@ under the License. + + create-checkpointing_custom_kv_state-jar + process-test-classes + + single + + + + + org.apache.flink.test.classloading.jar.CheckpointingCustomKvStateProgram + + + checkpointing_custom_kv_state + false + + src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml + + + diff --git a/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml new file mode 100644 index 0000000000000..fdebfdd3db806 --- /dev/null +++ b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml @@ -0,0 +1,38 @@ + + + + test-jar + + jar + + false + + + ${project.build.testOutputDirectory} + / + + + org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.class + org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram*.class + + + + diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java index 7afafe43274ae..65da33ffe35bc 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java @@ -39,6 +39,7 @@ import org.apache.flink.runtime.testingUtils.TestingCluster; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning; import org.apache.flink.test.testdata.KMeansData; +import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.TestLogger; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -46,7 +47,6 @@ import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.Option; import scala.concurrent.Await; import scala.concurrent.Future; import scala.concurrent.duration.Deadline; @@ -79,6 +79,8 @@ public class ClassLoaderITCase extends TestLogger { private static final String CUSTOM_KV_STATE_JAR_PATH = "custom_kv_state-test-jar.jar"; + private static final String CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH = "checkpointing_custom_kv_state-test-jar.jar"; + public static final TemporaryFolder FOLDER = new TemporaryFolder(); private static TestingCluster testCluster; @@ -199,9 +201,26 @@ public void testJobsWithCustomClassLoader() { }); userCodeTypeProg.invokeInteractiveModeForExecution(); + + File checkpointDir = FOLDER.newFolder(); + File outputDir = FOLDER.newFolder(); + + final PackagedProgram program = new PackagedProgram( + new File(CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH), + new String[] { + CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH, + "localhost", + String.valueOf(port), + checkpointDir.toURI().toString(), + outputDir.toURI().toString() + }); + + program.invokeInteractiveModeForExecution(); + } catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + if (!(e.getCause().getCause() instanceof SuccessException)) { + fail(e.getMessage()); + } } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java new file mode 100644 index 0000000000000..6796cb0b5fa17 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.classloading.jar; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.state.ReducingState; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; +import org.apache.flink.test.util.SuccessException; +import org.apache.flink.util.Collector; + +import java.io.IOException; +import java.util.concurrent.ThreadLocalRandom; + +public class CheckpointingCustomKvStateProgram { + + public static void main(String[] args) throws Exception { + final String jarFile = args[0]; + final String host = args[1]; + final int port = Integer.parseInt(args[2]); + final String checkpointPath = args[3]; + final String outputPath = args[4]; + final int parallelism = 1; + + StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment(host, port, jarFile); + + env.setParallelism(parallelism); + env.getConfig().disableSysoutLogging(); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 1000)); + env.setStateBackend(new FsStateBackend(checkpointPath)); + + DataStream source = env.addSource(new InfiniteIntegerSource()); + source + .map(new MapFunction>() { + private static final long serialVersionUID = 1L; + + @Override + public Tuple2 map(Integer value) throws Exception { + return new Tuple2<>(ThreadLocalRandom.current().nextInt(parallelism), value); + } + }) + .keyBy(new KeySelector, Integer>() { + private static final long serialVersionUID = 1L; + + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).flatMap(new ReducingStateFlatMap()).writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE); + + env.execute(); + } + + private static class InfiniteIntegerSource implements ParallelSourceFunction, Checkpointed { + private static final long serialVersionUID = -7517574288730066280L; + private volatile boolean running = true; + + @Override + public void run(SourceContext ctx) throws Exception { + int counter = 0; + while (running) { + synchronized (ctx.getCheckpointLock()) { + ctx.collect(counter++); + } + } + } + + @Override + public void cancel() { + running = false; + } + + @Override + public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { + return 0; + } + + @Override + public void restoreState(Integer state) throws Exception { + + } + } + + private static class ReducingStateFlatMap extends RichFlatMapFunction, Integer> implements Checkpointed, CheckpointListener { + + private static final long serialVersionUID = -5939722892793950253L; + private transient ReducingState kvState; + + private boolean atLeastOneSnapshotComplete = false; + private boolean restored = false; + + @Override + public void open(Configuration parameters) throws Exception { + ReducingStateDescriptor stateDescriptor = + new ReducingStateDescriptor<>( + "reducing-state", + new ReduceSum(), + CustomIntSerializer.INSTANCE); + + this.kvState = getRuntimeContext().getReducingState(stateDescriptor); + } + + + @Override + public void flatMap(Tuple2 value, Collector out) throws Exception { + kvState.add(value.f1); + + if(atLeastOneSnapshotComplete) { + if (restored) { + throw new SuccessException(); + } else { + throw new RuntimeException("Intended failure, to trigger restore"); + } + } + } + + @Override + public ReducingStateFlatMap snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { + return this; + } + + @Override + public void restoreState(ReducingStateFlatMap state) throws Exception { + restored = true; + atLeastOneSnapshotComplete = true; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + atLeastOneSnapshotComplete = true; + } + + private static class ReduceSum implements ReduceFunction { + private static final long serialVersionUID = 1L; + + @Override + public Integer reduce(Integer value1, Integer value2) throws Exception { + return value1 + value2; + } + } + } + + private static final class CustomIntSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 4572452915892737448L; + + public static final TypeSerializer INSTANCE = new CustomIntSerializer(); + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public Integer createInstance() { + return 0; + } + + @Override + public Integer copy(Integer from) { + return from; + } + + @Override + public Integer copy(Integer from, Integer reuse) { + return from; + } + + @Override + public int getLength() { + return 4; + } + + @Override + public void serialize(Integer record, DataOutputView target) throws IOException { + target.writeInt(record.intValue()); + } + + @Override + public Integer deserialize(DataInputView source) throws IOException { + return Integer.valueOf(source.readInt()); + } + + @Override + public Integer deserialize(Integer reuse, DataInputView source) throws IOException { + return Integer.valueOf(source.readInt()); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.writeInt(source.readInt()); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CustomIntSerializer; + } + + } +}