From 9a74a4b6116b9888388d1cd66f90dfdf7b904ca4 Mon Sep 17 00:00:00 2001 From: Stefan Richter Date: Wed, 5 Oct 2016 10:35:13 +0200 Subject: [PATCH] [hotfix] Create dedicated state descriptor to snapshot/restore Kafka consumer state --- .../typeutils/runtime/JavaSerializer.java | 27 ++++++----- .../apache/flink/util/InstantiationUtil.java | 12 ----- .../runtime/state/AbstractStateBackend.java | 4 +- .../state/DefaultOperatorStateBackend.java | 16 ++++++- .../runtime/state/OperatorStateStore.java | 13 +++++ .../state/RetrievableStreamStateHandle.java | 2 +- .../zookeeper/ZooKeeperStateHandleStore.java | 2 +- .../checkpoint/CheckpointCoordinatorTest.java | 19 +++++--- .../jobmanager/JobManagerHARecoveryTest.java | 7 ++- .../state/OperatorStateBackendTest.java | 15 ++++-- .../kafka/FlinkKafkaConsumerBase.java | 15 ++---- .../kafka/FlinkKafkaConsumerBaseTest.java | 47 ++++++++++--------- .../api/checkpoint/ListCheckpointed.java | 5 -- .../operators/AbstractUdfStreamOperator.java | 10 ++-- .../runtime/tasks/OneInputStreamTaskTest.java | 8 ++-- .../EventTimeWindowCheckpointingITCase.java | 7 +-- 16 files changed, 118 insertions(+), 91 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java index 4ae00d1e57f8e..3af7653c8d4cb 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java @@ -22,16 +22,25 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.Preconditions; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; public class JavaSerializer extends TypeSerializer { private static final
long serialVersionUID = 1L; + private final ClassLoader userClassLoader; + + public JavaSerializer() { + this(Thread.currentThread().getContextClassLoader()); + } + + public JavaSerializer(ClassLoader userClassLoader) { + this.userClassLoader = Preconditions.checkNotNull(userClassLoader); + } + @Override public boolean isImmutableType() { return false; @@ -69,21 +78,15 @@ public int getLength() { @Override public void serialize(T record, DataOutputView target) throws IOException { - ObjectOutputStream oos = new ObjectOutputStream(new DataOutputViewStream(target)); - oos.writeObject(record); - oos.flush(); + InstantiationUtil.serializeObject(new DataOutputViewStream(target), record); } @Override public T deserialize(DataInputView source) throws IOException { - ObjectInputStream ois = new ObjectInputStream(new DataInputViewStream(source)); - try { - @SuppressWarnings("unchecked") - T nfa = (T) ois.readObject(); - return nfa; + return InstantiationUtil.deserializeObject(new DataInputViewStream(source), userClassLoader); } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not deserialize NFA.", e); + throw new IOException("Could not deserialize object.", e); } } @@ -101,7 +104,7 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object obj) { - return obj instanceof JavaSerializer && ((JavaSerializer) obj).canEqual(this); + return obj instanceof JavaSerializer && userClassLoader.equals(((JavaSerializer) obj).userClassLoader); } @Override diff --git a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java index de4cffbc7ee8f..cd5c91a1675a8 100644 --- a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java +++ b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java @@ -311,18 +311,6 @@ public static T deserializeObject(InputStream in, ClassLoader cl) throws IOE } } - 
@SuppressWarnings("unchecked") - public static T deserializeObject(byte[] bytes) throws IOException, ClassNotFoundException { - ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes); - return deserializeObject(byteArrayInputStream); - } - - @SuppressWarnings("unchecked") - public static T deserializeObject(InputStream in) throws IOException, ClassNotFoundException { - ObjectInputStream objectInputStream = new ObjectInputStream(in); - return (T) objectInputStream.readObject(); - } - public static byte[] serializeObject(Object o) throws IOException { try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java index c2e665b579554..c683a0236eca1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java @@ -83,7 +83,7 @@ public OperatorStateBackend createOperatorStateBackend( Environment env, String operatorIdentifier ) throws Exception { - return new DefaultOperatorStateBackend(); + return new DefaultOperatorStateBackend(env.getUserClassLoader()); } /** @@ -95,6 +95,6 @@ public OperatorStateBackend restoreOperatorStateBackend( String operatorIdentifier, Collection restoreSnapshots ) throws Exception { - return new DefaultOperatorStateBackend(restoreSnapshots); + return new DefaultOperatorStateBackend(env.getUserClassLoader(), restoreSnapshots); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index 0bd5eeb37faab..af97a3f177dc8 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.runtime.JavaSerializer; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.memory.DataInputView; @@ -30,6 +31,7 @@ import org.apache.flink.util.Preconditions; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -46,6 +48,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { private final Map> registeredStates; private final Collection restoreSnapshots; private final ClosableRegistry closeStreamOnCancelRegistry; + private final JavaSerializer javaSerializer; /** * Restores a OperatorStateStore (lazily) using the provided snapshots. @@ -53,7 +56,11 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { * @param restoreSnapshots snapshots that are available to restore partitionable states on request. */ public DefaultOperatorStateBackend( + ClassLoader userClassLoader, Collection restoreSnapshots) { + + Preconditions.checkNotNull(userClassLoader); + this.javaSerializer = new JavaSerializer<>(userClassLoader); this.restoreSnapshots = restoreSnapshots; this.registeredStates = new HashMap<>(); this.closeStreamOnCancelRegistry = new ClosableRegistry(); @@ -62,8 +69,13 @@ public DefaultOperatorStateBackend( /** * Creates an empty OperatorStateStore. 
*/ - public DefaultOperatorStateBackend() { - this(null); + public DefaultOperatorStateBackend(ClassLoader userClassLoader) { + this(userClassLoader, null); + } + + @Override + public ListState getDefaultPartitionableState(String stateName) throws Exception { + return getPartitionableState(new ListStateDescriptor<>(stateName, javaSerializer)); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java index 6914a7ce89c63..ceab87f19b66a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java @@ -20,7 +20,9 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.java.typeutils.runtime.JavaSerializer; +import java.io.Serializable; import java.util.Set; /** @@ -28,6 +30,17 @@ */ public interface OperatorStateStore { + String DEFAULT_OPERATOR_STATE_NAME = ""; + + /** + * Creates a state descriptor of the given name that uses {@link JavaSerializer}. + * + * @param stateName The name of state to create + * @return A state descriptor that uses {@link JavaSerializer} + * @throws Exception + */ + ListState getDefaultPartitionableState(String stateName) throws Exception; + + /** * Creates (or restores) the partitionable state in this backend. Each state is registered under a unique name. * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java index 99343824a11c5..29d21acd88bb2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java @@ -55,7 +55,7 @@ public RetrievableStreamStateHandle(Path filePath, long stateSize) { @Override public T retrieveState() throws Exception { try (FSDataInputStream in = openInputStream()) { - return InstantiationUtil.deserializeObject(in); + return InstantiationUtil.deserializeObject(in, Thread.currentThread().getContextClassLoader()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java index d62b13ed5e635..5623715b639f1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java @@ -219,7 +219,7 @@ public RetrievableStateHandle get(String pathInZooKeeper) throws Exception { checkNotNull(pathInZooKeeper, "Path in ZooKeeper"); byte[] data = client.getData().forPath(pathInZooKeeper); - return InstantiationUtil.deserializeObject(data); + return InstantiationUtil.deserializeObject(data, Thread.currentThread().getContextClassLoader()); } /** diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index c39e436ee8b4e..5fb0e6f46c803 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -2478,14 +2478,19 @@ public static void compareKeyPartitionedState( for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) { long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId); inputStream.seek(offset); - int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream); + int expectedKeyGroupState = + InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader()); for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) { if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) { long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId); - try (FSDataInputStream actualInputStream = - oneActualKeyGroupStateHandle.getStateHandle().openInputStream()) { + try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle. + getStateHandle().openInputStream()) { + actualInputStream.seek(actualOffset); - int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream); + + int actualGroupState = InstantiationUtil. + deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader()); + assertEquals(expectedKeyGroupState, actualGroupState); } } @@ -2506,7 +2511,8 @@ public static void comparePartitionableState( for (Map.Entry entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { for (long offset : entry.getValue()) { in.seek(offset); - Integer state = InstantiationUtil.deserializeObject(in); + Integer state = InstantiationUtil. 
+ deserializeObject(in, Thread.currentThread().getContextClassLoader()); expectedResult.add(i + " : " + entry.getKey() + " : " + state); } } @@ -2525,7 +2531,8 @@ public static void comparePartitionableState( for (Map.Entry entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) { for (long offset : entry.getValue()) { in.seek(offset); - Integer state = InstantiationUtil.deserializeObject(in); + Integer state = InstantiationUtil. + deserializeObject(in, Thread.currentThread().getContextClassLoader()); actualResult.add(i + " : " + entry.getKey() + " : " + state); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java index 9b12cac08180f..38231ec89a7ec 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java @@ -452,7 +452,7 @@ public void setInitialState( int subtaskIndex = getIndexInSubtaskGroup(); if (subtaskIndex < recoveredStates.length) { try (FSDataInputStream in = chainedState.get(0).openInputStream()) { - recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in); + recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader()); } } } @@ -464,7 +464,10 @@ public boolean triggerCheckpoint(long checkpointId, long timestamp) { InstantiationUtil.serializeObject(checkpointId)); RetrievableStreamStateHandle state = new RetrievableStreamStateHandle(byteStreamStateHandle); - ChainedStateHandle chainedStateHandle = new ChainedStateHandle(Collections.singletonList(state)); + + ChainedStateHandle chainedStateHandle = + new ChainedStateHandle(Collections.singletonList(state)); + CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(chainedStateHandle, null, Collections.emptyList()); diff 
--git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java index 56c898741e8ee..ff1a23dafc6cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.java.typeutils.runtime.JavaSerializer; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Test; @@ -31,13 +32,21 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class OperatorStateBackendTest { AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024); + static Environment createMockEnvironment() { + Environment env = mock(Environment.class); + when(env.getUserClassLoader()).thenReturn(Thread.currentThread().getContextClassLoader()); + return env; + } + private OperatorStateBackend createNewOperatorStateBackend() throws Exception { - return abstractStateBackend.createOperatorStateBackend(null, "test-operator"); + return abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-operator"); } @Test @@ -123,8 +132,8 @@ public void testSnapshotRestore() throws Exception { operatorStateBackend.dispose(); - operatorStateBackend = abstractStateBackend. 
- restoreOperatorStateBackend(null, "testOperator", Collections.singletonList(stateHandle)); + operatorStateBackend = abstractStateBackend.restoreOperatorStateBackend( + createMockEnvironment(), "testOperator", Collections.singletonList(stateHandle)); assertEquals(0, operatorStateBackend.getRegisteredStateNames().size()); diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java index 939b77b474d37..a30341b5a69d8 100644 --- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java @@ -19,7 +19,6 @@ import org.apache.commons.collections.map.LinkedMap; import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; @@ -27,7 +26,6 @@ import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.OperatorStateStore; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; @@ -67,8 +65,6 @@ public abstract class FlinkKafkaConsumerBase extends RichParallelSourceFuncti CheckpointedFunction { private static final long serialVersionUID = 
-6272159445203409112L; - private static final String KAFKA_OFFSETS = "kafka_offsets"; - protected static final Logger LOG = LoggerFactory.getLogger(FlinkKafkaConsumerBase.class); /** The maximum number of pending non-committed checkpoints to track, to avoid memory leaks */ @@ -130,9 +126,6 @@ public FlinkKafkaConsumerBase(List topics, KeyedDeserializationSchema checkArgument(topics.size() > 0, "You have to define at least one topic."); this.deserializer = checkNotNull(deserializer, "valueDeserializer"); - - TypeInformation> typeInfo = - TypeInformation.of(new TypeHint>(){}); } /** @@ -314,13 +307,13 @@ public void close() throws Exception { // Checkpoint and restore // ------------------------------------------------------------------------ - @Override public void initializeState(OperatorStateStore stateStore) throws Exception { this.stateStore = stateStore; - ListState offsets = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); + ListState offsets = + stateStore.getDefaultPartitionableState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); restoreToOffset = new HashMap<>(); @@ -339,8 +332,8 @@ public void prepareSnapshot(long checkpointId, long timestamp) throws Exception LOG.debug("storeOperatorState() called on closed source"); } else { - ListState listState = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); - + ListState listState = + stateStore.getDefaultPartitionableState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); listState.clear(); final AbstractFetcher fetcher = this.kafkaFetcher; diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java index fc8b7e964650d..45b45f032c4ea 100644 --- 
a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java +++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java @@ -21,7 +21,6 @@ import org.apache.commons.collections.map.LinkedMap; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.state.OperatorStateStore; import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks; @@ -34,6 +33,7 @@ import org.junit.Test; import org.mockito.Matchers; +import java.io.Serializable; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; @@ -47,8 +47,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -115,31 +113,31 @@ public void ignoreCheckpointWhenNotRunning() throws Exception { public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception { OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - TestingListState> expectedState = new TestingListState<>(); + TestingListState expectedState = new TestingListState<>(); expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L)); expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L)); - TestingListState> listState = new TestingListState<>(); + TestingListState listState = new TestingListState<>(); FlinkKafkaConsumerBase consumer = getConsumer(null, new LinkedMap(), true); - 
when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(expectedState); + when(operatorStateStore.getDefaultPartitionableState(Matchers.any(String.class))).thenReturn(expectedState); consumer.initializeState(operatorStateStore); - when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); + when(operatorStateStore.getDefaultPartitionableState(Matchers.any(String.class))).thenReturn(listState); consumer.prepareSnapshot(17L, 17L); - Set> expected = new HashSet>(); + Set expected = new HashSet<>(); - for (Tuple2 kafkaTopicPartitionLongTuple2 : expectedState.get()) { - expected.add(kafkaTopicPartitionLongTuple2); + for (Serializable serializable : expectedState.get()) { + expected.add(serializable); } int counter = 0; - for (Tuple2 kafkaTopicPartitionLongTuple2 : listState.get()) { - assertTrue(expected.contains(kafkaTopicPartitionLongTuple2)); + for (Serializable serializable : listState.get()) { + assertTrue(expected.contains(serializable)); counter++; } @@ -154,8 +152,8 @@ public void checkRestoredNullCheckpointWhenFetcherNotReady() throws Exception { FlinkKafkaConsumerBase consumer = getConsumer(null, new LinkedMap(), true); OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - TestingListState> listState = new TestingListState<>(); - when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); + TestingListState listState = new TestingListState<>(); + when(operatorStateStore.getDefaultPartitionableState(Matchers.any(String.class))).thenReturn(listState); consumer.initializeState(operatorStateStore); consumer.prepareSnapshot(17L, 17L); @@ -188,12 +186,12 @@ public void testSnapshotState() throws Exception { OperatorStateStore backend = mock(OperatorStateStore.class); - TestingListState> listState1 = new TestingListState<>(); - TestingListState> listState2 = new TestingListState<>(); - TestingListState> 
listState3 = new TestingListState<>(); + TestingListState listState1 = new TestingListState<>(); + TestingListState listState2 = new TestingListState<>(); + TestingListState listState3 = new TestingListState<>(); - when(backend.getPartitionableState(Matchers.any(ListStateDescriptor.class))). - thenReturn(listState1, listState1, listState2, listState2, listState3, listState3); + when(backend.getDefaultPartitionableState(Matchers.any(String.class))). + thenReturn(listState1, listState1, listState2, listState3); consumer.initializeState(backend); @@ -202,7 +200,8 @@ public void testSnapshotState() throws Exception { HashMap snapshot1 = new HashMap<>(); - for (Tuple2 kafkaTopicPartitionLongTuple2 : listState1.get()) { + for (Serializable serializable : listState1.get()) { + Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } @@ -215,7 +214,8 @@ public void testSnapshotState() throws Exception { HashMap snapshot2 = new HashMap<>(); - for (Tuple2 kafkaTopicPartitionLongTuple2 : listState2.get()) { + for (Serializable serializable : listState2.get()) { + Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } @@ -233,8 +233,9 @@ public void testSnapshotState() throws Exception { HashMap snapshot3 = new HashMap<>(); - for (Tuple2 kafkaTopicPartitionLongTuple2 : listState1.get()) { - snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); + for (Serializable serializable : listState3.get()) { + Tuple2 kafkaTopicPartitionLongTuple2 = (Tuple2) serializable; + snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1); } assertEquals(state3, snapshot3); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java index 430b2b98ab201..1031b8848e742 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java @@ -19,8 +19,6 @@ package org.apache.flink.streaming.api.checkpoint; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.java.typeutils.runtime.JavaSerializer; import java.io.Serializable; import java.util.List; @@ -36,9 +34,6 @@ @PublicEvolving public interface ListCheckpointed { - ListStateDescriptor DEFAULT_LIST_DESCRIPTOR = - new ListStateDescriptor<>("", new JavaSerializer<>()); - /** * Gets the current state of the function of operator. The state must reflect the result of all * prior invocations to this function. diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index f683d9a93a38b..428442d231bc2 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.OperatorStateStore; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; @@ -70,7 +71,6 @@ public abstract class AbstractUdfStreamOperator /** Flag 
to prevent duplicate function.close() calls in close() and dispose() */ private transient boolean functionsClosed = false; - public AbstractUdfStreamOperator(F userFunction) { this.userFunction = requireNonNull(userFunction); } @@ -107,8 +107,8 @@ public void open() throws Exception { @SuppressWarnings("unchecked") ListCheckpointed listCheckpointedFun = (ListCheckpointed) userFunction; - ListState listState = - getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); + ListState listState = getOperatorStateBackend(). + getDefaultPartitionableState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); List list = new ArrayList<>(); @@ -201,8 +201,8 @@ public RunnableFuture snapshotState( List partitionableState = ((ListCheckpointed) userFunction).snapshotState(checkpointId, timestamp); - ListState listState = - getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR); + ListState listState = getOperatorStateBackend(). + getDefaultPartitionableState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME); listState.clear(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index 4003e59a46b9a..31ccc283050df 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -39,12 +39,12 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.graph.StreamEdge; import 
org.apache.flink.streaming.api.graph.StreamNode; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -632,8 +632,10 @@ public void restoreState(FSDataInputStream in) throws Exception { assertNotNull(in); - Serializable functionState= InstantiationUtil.deserializeObject(in); - Integer operatorState= InstantiationUtil.deserializeObject(in); + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + + Serializable functionState= InstantiationUtil.deserializeObject(in, cl); + Integer operatorState= InstantiationUtil.deserializeObject(in, cl); assertEquals(random.nextInt(), functionState); assertEquals(random.nextInt(), (int) operatorState); diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeWindowCheckpointingITCase.java index 0aee128a24042..0687f665d7129 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeWindowCheckpointingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeWindowCheckpointingITCase.java @@ -30,10 +30,10 @@ import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster; import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.TimeCharacteristic; -import 
org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; @@ -45,7 +45,6 @@ import org.apache.flink.test.util.SuccessException; import org.apache.flink.util.Collector; import org.apache.flink.util.TestLogger; - import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -62,7 +61,9 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.apache.flink.test.util.TestUtils.tryExecute; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * This verifies that checkpointing works correctly with event time windows. This is more