From 0d3b605ef7b382ba26ec2189fa6af1ca04d2129d Mon Sep 17 00:00:00 2001
From: Ufuk Celebi
Date: Mon, 30 May 2016 13:42:39 +0200
Subject: [PATCH 1/6] [FLINK-3779] [runtime] Add getSerializedValue(byte[]) to KvState

[statebackend-rocksdb, core, streaming-java]

- Adds getSerializedValue(byte[]) to KvState, which is used to query single
  KvState instances. The serialization is left to the KvState in order not to
  burden the accessor -- e.g. the querying network thread -- with setting up
  and accessing the serializers.

- Adds a queryable flag to the StateDescriptor. State that sets a queryable
  state name will be published for queries to the KvStateRegistry.

- Prohibits null namespaces and enforces VoidNamespace instead. This makes
  things more explicit. Furthermore, the concurrent map used for queryable
  memory state does not allow working with null keys.

---
 .../streaming/state/AbstractRocksDBState.java | 40 +-
 .../streaming/state/RocksDBFoldingState.java | 15 +-
 .../streaming/state/RocksDBListState.java | 13 +-
 .../streaming/state/RocksDBReducingState.java | 13 +-
 .../streaming/state/RocksDBValueState.java | 29 +-
 .../state/RocksDBAsyncKVSnapshotTest.java | 13 +-
 .../state/RocksDBStateBackendConfigTest.java | 8 +-
 .../api/common/state/StateDescriptor.java | 42 ++
 .../runtime/state/AbstractHeapState.java | 74 ++-
 .../runtime/state/AbstractStateBackend.java | 64 ++-
 .../runtime/state/GenericFoldingState.java | 11 +
 .../flink/runtime/state/GenericListState.java | 11 +
 .../runtime/state/GenericReducingState.java | 11 +
 .../apache/flink/runtime/state/KvState.java | 19 +
 .../flink/runtime/state/VoidNamespace.java | 35 ++
 .../state/VoidNamespaceSerializer.java | 92 ++++
 .../state/filesystem/FsFoldingState.java | 33 +-
 .../runtime/state/filesystem/FsListState.java | 34 +-
 .../state/filesystem/FsReducingState.java | 24 +-
 .../state/filesystem/FsStateBackend.java | 3 +-
 .../state/filesystem/FsValueState.java | 25 +-
 .../runtime/state/memory/MemFoldingState.java | 32 +-
 .../runtime/state/memory/MemListState.java | 32 +-
 .../state/memory/MemReducingState.java | 24 +-
 .../runtime/state/memory/MemValueState.java | 25 +-
 .../runtime/state/FileStateBackendTest.java | 6 +
 .../runtime/state/MemoryStateBackendTest.java | 8 +
 .../runtime/state/StateBackendTestBase.java | 434 ++++++++++++++++--
 .../api/operators/AbstractStreamOperator.java | 11 +-
 .../operators/windowing/WindowOperator.java | 7 +-
 .../StreamingRuntimeContextTest.java | 19 +-
 31 files changed, 1037 insertions(+), 170 deletions(-)
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespace.java
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java

diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
index f057b6eb726d1..3c4a2091c276b 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
@@ -22,12 +22,11 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; - import org.apache.flink.runtime.state.KvState; import 
org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; - import org.rocksdb.WriteOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,7 +45,7 @@ * @param The type of {@link State}. * @param The type of {@link StateDescriptor}. */ -public abstract class AbstractRocksDBState> +public abstract class AbstractRocksDBState, V> implements KvState, State { private static final Logger LOG = LoggerFactory.getLogger(AbstractRocksDBState.class); @@ -63,6 +62,9 @@ public abstract class AbstractRocksDBState namespaceSerializer, + SD stateDesc, RocksDBStateBackend backend) { this.namespaceSerializer = namespaceSerializer; @@ -84,6 +86,8 @@ protected AbstractRocksDBState(ColumnFamilyHandle columnFamily, writeOptions = new WriteOptions(); writeOptions.setDisableWAL(true); + + this.stateDesc = Preconditions.checkNotNull(stateDesc, "State Descriptor"); } // ------------------------------------------------------------------------ @@ -109,7 +113,7 @@ protected void writeKeyAndNamespace(DataOutputView out) throws IOException { @Override public void setCurrentNamespace(N namespace) { - this.currentNamespace = namespace; + this.currentNamespace = Preconditions.checkNotNull(namespace, "Namespace"); } @Override @@ -117,10 +121,14 @@ public void dispose() { // ignore because we don't hold any state ourselves } + @Override + public SD getStateDescriptor() { + return stateDesc; + } + @Override public void setCurrentKey(K key) { // ignore because we don't hold any state ourselves - } @Override @@ -128,5 +136,21 @@ public KvStateSnapshot snapshot(long checkpoin long timestamp) throws Exception { throw new RuntimeException("Should not be called. Backups happen in RocksDBStateBackend."); } -} + @Override + @SuppressWarnings("unchecked") + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + // Serialized key and namespace is expected to be of the same format + // as writeKeyAndNamespace() + Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + + byte[] value = backend.db.get(columnFamily, serializedKeyAndNamespace); + + if (value != null) { + return value; + } else { + return null; + } + } + +} diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java index 218fa2a56dca8..f1cf40995a60e 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java @@ -24,7 +24,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; - import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; import org.rocksdb.WriteOptions; @@ -33,8 +32,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import static java.util.Objects.requireNonNull; - /** * {@link FoldingState} implementation that stores state in RocksDB. * @@ -44,15 +41,12 @@ * @param The type of the value in the folding state. 
*/ public class RocksDBFoldingState - extends AbstractRocksDBState, FoldingStateDescriptor> + extends AbstractRocksDBState, FoldingStateDescriptor, ACC> implements FoldingState { /** Serializer for the values */ private final TypeSerializer valueSerializer; - /** This holds the name of the state and can create an initial default value for the state. */ - private final FoldingStateDescriptor stateDesc; - /** User-specified fold function */ private final FoldFunction foldFunction; @@ -74,9 +68,8 @@ public RocksDBFoldingState(ColumnFamilyHandle columnFamily, FoldingStateDescriptor stateDesc, RocksDBStateBackend backend) { - super(columnFamily, namespaceSerializer, backend); - - this.stateDesc = requireNonNull(stateDesc); + super(columnFamily, namespaceSerializer, stateDesc, backend); + this.valueSerializer = stateDesc.getSerializer(); this.foldFunction = stateDesc.getFoldFunction(); @@ -125,5 +118,5 @@ public void add(T value) throws IOException { throw new RuntimeException("Error while adding data to RocksDB", e); } } -} +} diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java index ce3a48e98f1ad..ff1038e68c905 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java @@ -23,7 +23,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; - import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; import org.rocksdb.WriteOptions; @@ -34,8 +33,6 @@ import java.util.ArrayList; import java.util.List; -import static java.util.Objects.requireNonNull; - /** * {@link ListState} implementation that stores state in RocksDB. * @@ -48,15 +45,12 @@ * @param The type of the values in the list state. */ public class RocksDBListState - extends AbstractRocksDBState, ListStateDescriptor> + extends AbstractRocksDBState, ListStateDescriptor, V> implements ListState { /** Serializer for the values */ private final TypeSerializer valueSerializer; - /** This holds the name of the state and can create an initial default value for the state. */ - private final ListStateDescriptor stateDesc; - /** * We disable writes to the write-ahead-log here. We can't have these in the base class * because JNI segfaults for some reason if they are. 
@@ -75,8 +69,7 @@ public RocksDBListState(ColumnFamilyHandle columnFamily, ListStateDescriptor stateDesc, RocksDBStateBackend backend) { - super(columnFamily, namespaceSerializer, backend); - this.stateDesc = requireNonNull(stateDesc); + super(columnFamily, namespaceSerializer, stateDesc, backend); this.valueSerializer = stateDesc.getSerializer(); writeOptions = new WriteOptions(); @@ -129,5 +122,5 @@ public void add(V value) throws IOException { throw new RuntimeException("Error while adding data to RocksDB", e); } } -} +} diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBReducingState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBReducingState.java index 0cdd3edae08ab..efa29310cc82b 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBReducingState.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBReducingState.java @@ -24,7 +24,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; - import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; import org.rocksdb.WriteOptions; @@ -33,8 +32,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import static java.util.Objects.requireNonNull; - /** * {@link ReducingState} implementation that stores state in RocksDB. * @@ -43,15 +40,12 @@ * @param The type of value that the state state stores. */ public class RocksDBReducingState - extends AbstractRocksDBState, ReducingStateDescriptor> + extends AbstractRocksDBState, ReducingStateDescriptor, V> implements ReducingState { /** Serializer for the values */ private final TypeSerializer valueSerializer; - /** This holds the name of the state and can create an initial default value for the state. 
*/ - private final ReducingStateDescriptor stateDesc; - /** User-specified reduce function */ private final ReduceFunction reduceFunction; @@ -73,8 +67,7 @@ public RocksDBReducingState(ColumnFamilyHandle columnFamily, ReducingStateDescriptor stateDesc, RocksDBStateBackend backend) { - super(columnFamily, namespaceSerializer, backend); - this.stateDesc = requireNonNull(stateDesc); + super(columnFamily, namespaceSerializer, stateDesc, backend); this.valueSerializer = stateDesc.getSerializer(); this.reduceFunction = stateDesc.getReduceFunction(); @@ -123,5 +116,5 @@ public void add(V value) throws IOException { throw new RuntimeException("Error while adding data to RocksDB", e); } } -} +} diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBValueState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBValueState.java index 1a5cb9ef7d4dc..62bc36608886f 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBValueState.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBValueState.java @@ -23,7 +23,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; - +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.util.Preconditions; import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; import org.rocksdb.WriteOptions; @@ -32,8 +33,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import static java.util.Objects.requireNonNull; - /** * {@link ValueState} implementation that stores state in RocksDB. * @@ -42,15 +41,12 @@ * @param The type of value that the state state stores. */ public class RocksDBValueState - extends AbstractRocksDBState, ValueStateDescriptor> + extends AbstractRocksDBState, ValueStateDescriptor, V> implements ValueState { /** Serializer for the values */ private final TypeSerializer valueSerializer; - /** This holds the name of the state and can create an initial default value for the state. */ - private final ValueStateDescriptor stateDesc; - /** * We disable writes to the write-ahead-log here. We can't have these in the base class * because JNI segfaults for some reason if they are. 
@@ -69,8 +65,7 @@ public RocksDBValueState(ColumnFamilyHandle columnFamily, ValueStateDescriptor stateDesc, RocksDBStateBackend backend) { - super(columnFamily, namespaceSerializer, backend); - this.stateDesc = requireNonNull(stateDesc); + super(columnFamily, namespaceSerializer, stateDesc, backend); this.valueSerializer = stateDesc.getSerializer(); writeOptions = new WriteOptions(); @@ -112,5 +107,19 @@ public void update(V value) throws IOException { throw new RuntimeException("Error while adding data to RocksDB", e); } } -} + @Override + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + // Serialized key and namespace is expected to be of the same format + // as writeKeyAndNamespace() + Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + + byte[] value = backend.db.get(columnFamily, serializedKeyAndNamespace); + + if (value != null) { + return value; + } else { + return KvStateRequestSerializer.serializeValue(stateDesc.getDefaultValue(), stateDesc.getSerializer()); + } + } +} diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncKVSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncKVSnapshotTest.java index a58686b260f9c..7118cf6beeaf9 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncKVSnapshotTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncKVSnapshotTest.java @@ -23,13 +23,14 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -295,8 +296,9 @@ public void open() throws Exception { // also get the state in open, this way we are sure that it was created before // we trigger the test checkpoint - ValueState state = getPartitionedState(null, - VoidSerializer.INSTANCE, + ValueState state = getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, new ValueStateDescriptor<>("count", StringSerializer.INSTANCE, "hello")); @@ -306,8 +308,9 @@ public void open() throws Exception { public void processElement(StreamRecord element) throws Exception { // we also don't care - ValueState state = getPartitionedState(null, - VoidSerializer.INSTANCE, + ValueState state = getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, new ValueStateDescriptor<>("count", StringSerializer.INSTANCE, "hello")); diff --git 
a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java index fca577369f9ab..0878b8c5998f8 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java @@ -23,11 +23,12 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.util.OperatingSystem; import org.junit.Assume; import org.junit.Before; @@ -176,7 +177,10 @@ public void testContinueOnSomeDbDirectoriesMissing() throws Exception { rocksDbBackend.initializeForJob(getMockEnvironment(), "foobar", IntSerializer.INSTANCE); // actually get a state to see whether we can write to the storage directory - rocksDbBackend.getPartitionedState(null, VoidSerializer.INSTANCE, new ValueStateDescriptor<>("test", String.class, "")); + rocksDbBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("test", String.class, "")); } catch (Exception e) { e.printStackTrace(); diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java index 87ed71d6b1a8b..d99f4de26c259 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java @@ -25,6 +25,7 @@ import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.util.Preconditions; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -55,6 +56,9 @@ public abstract class StateDescriptor implements Serializabl * or lazily once the type is serialized or an ExecutionConfig is provided. */ protected TypeSerializer serializer; + /** Name for queries against state created from this StateDescriptor. */ + private String queryableStateName; + /** The default value returned by the state when no other value is bound to a key */ protected transient T defaultValue; @@ -153,6 +157,43 @@ public TypeSerializer getSerializer() { } } + /** + * Sets the name for queries of state created from this descriptor. + * + *
<p>
If a name is set, the created state will be published for queries + * during runtime. The name needs to be unique per job. If there is another + * state instance published under the same name, the job will fail during runtime. + * + * @param queryableStateName State name for queries (unique name per job) + * @throws IllegalStateException If queryable state name already set + */ + public void setQueryable(String queryableStateName) { + if (this.queryableStateName == null) { + this.queryableStateName = Preconditions.checkNotNull(queryableStateName, "Registration name"); + } else { + throw new IllegalStateException("Queryable state name already set"); + } + } + + /** + * Returns the queryable state name. + * + * @return Queryable state name or null if not set. + */ + public String getQueryableStateName() { + return queryableStateName; + } + + /** + * Returns whether the state created from this descriptor is queryable. + * + * @return true if state is queryable, false + * otherwise. + */ + public boolean isQueryable() { + return queryableStateName != null; + } + /** * Creates a new {@link State} on the given {@link StateBackend}. * @@ -221,6 +262,7 @@ public String toString() { "{name=" + name + ", defaultValue=" + defaultValue + ", serializer=" + serializer + + (isQueryable() ? ", queryableStateName=" + queryableStateName + "" : "") + '}'; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractHeapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractHeapState.java index 8e77752894e10..6fa45752a8fb9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractHeapState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractHeapState.java @@ -22,11 +22,13 @@ import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.util.Preconditions; import java.util.HashMap; import java.util.Map; - -import static java.util.Objects.requireNonNull; +import java.util.concurrent.ConcurrentHashMap; /** * Base class for partitioned {@link ListState} implementations that are backed by a regular @@ -43,7 +45,7 @@ public abstract class AbstractHeapState, State { /** Map containing the actual key/value pairs */ - protected final HashMap> state; + protected final Map> state; /** Serializer for the state value. The state value could be a List, for example. */ protected final TypeSerializer stateSerializer; @@ -93,10 +95,20 @@ protected AbstractHeapState(TypeSerializer keySerializer, TypeSerializer namespaceSerializer, TypeSerializer stateSerializer, SD stateDesc, - HashMap> state) { - this.state = requireNonNull(state); - this.keySerializer = requireNonNull(keySerializer); - this.namespaceSerializer = requireNonNull(namespaceSerializer); + Map> state) { + + Preconditions.checkNotNull(state, "State map"); + + // Make sure that the state map supports concurrent read access for + // queries. See also #createNewNamespaceMap for the namespace maps. 
+ if (stateDesc.isQueryable()) { + this.state = new ConcurrentHashMap<>(state); + } else { + this.state = state; + } + + this.keySerializer = Preconditions.checkNotNull(keySerializer); + this.namespaceSerializer = Preconditions.checkNotNull(namespaceSerializer); this.stateSerializer = stateSerializer; this.stateDesc = stateDesc; } @@ -116,7 +128,7 @@ public final void clear() { @Override public final void setCurrentKey(K currentKey) { - this.currentKey = currentKey; + this.currentKey = Preconditions.checkNotNull(currentKey, "Key"); } @Override @@ -124,10 +136,22 @@ public final void setCurrentNamespace(N namespace) { if (namespace != null && namespace.equals(this.currentNamespace)) { return; } - this.currentNamespace = namespace; + this.currentNamespace = Preconditions.checkNotNull(namespace, "Namespace"); this.currentNSState = state.get(currentNamespace); } + @Override + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + + Tuple2 keyAndNamespace = KvStateRequestSerializer.deserializeKeyAndNamespace( + serializedKeyAndNamespace, keySerializer, namespaceSerializer); + + return getSerializedValue(keyAndNamespace.f0, keyAndNamespace.f1); + } + + protected abstract byte[] getSerializedValue(K key, N namespace) throws Exception; + /** * Returns the number of all state pairs in this state, across namespaces. */ @@ -144,6 +168,11 @@ public void dispose() { state.clear(); } + @Override + public SD getStateDescriptor() { + return stateDesc; + } + /** * Gets the serializer for the keys. * @@ -161,4 +190,31 @@ public final TypeSerializer getKeySerializer() { public final TypeSerializer getNamespaceSerializer() { return namespaceSerializer; } + + /** + * Creates a new namespace map. + * + *
<p>
If the state queryable ({@link StateDescriptor#isQueryable()}, this + * will create a concurrent hash map instead of a regular one. + * + * @return A new namespace map. + */ + protected Map createNewNamespaceMap() { + if (stateDesc.isQueryable()) { + return new ConcurrentHashMap<>(); + } else { + return new HashMap<>(); + } + } + + // ------------------------------------------------------------------------ + + /** + * Returns the internal state map for testing. + * + * @return The internal state map + */ + Map> getStateMap() { + return state; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java index b86688b93c82b..6fc947582617c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java @@ -38,6 +38,8 @@ import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.io.OutputStream; @@ -73,6 +75,12 @@ public abstract class AbstractStateBackend implements java.io.Serializable { @SuppressWarnings("rawtypes") private transient KvState lastState; + /** KvStateRegistry helper for this task */ + protected transient TaskKvStateRegistry kvStateRegistry; + + /** Key group index of this state backend */ + protected transient int keyGroupIndex; + // ------------------------------------------------------------------------ // initialization and cleanup // ------------------------------------------------------------------------ @@ -87,11 +95,16 @@ public abstract class AbstractStateBackend implements java.io.Serializable { * case the job that uses the state backend is considered failed during * deployment. 
*/ - public void initializeForJob(Environment env, - String operatorIdentifier, - TypeSerializer keySerializer) throws Exception { + public void initializeForJob( + Environment env, + String operatorIdentifier, + TypeSerializer keySerializer) throws Exception { + this.userCodeClassLoader = env.getUserClassLoader(); this.keySerializer = keySerializer; + + this.keyGroupIndex = env.getTaskInfo().getIndexOfThisSubtask(); + this.kvStateRegistry = env.getTaskKvStateRegistry(); } /** @@ -110,6 +123,10 @@ public void initializeForJob(Environment env, public abstract void close() throws Exception; public void dispose() { + if (kvStateRegistry != null) { + kvStateRegistry.unregisterAll(); + } + lastName = null; lastState = null; if (keyValueStates != null) { @@ -176,7 +193,7 @@ public void dispose() { */ @SuppressWarnings({"unchecked", "rawtypes"}) public void setCurrentKey(Object currentKey) { - this.currentKey = currentKey; + this.currentKey = Preconditions.checkNotNull(currentKey, "Key"); if (keyValueStates != null) { for (KvState kv : keyValueStates) { kv.setCurrentKey(currentKey); @@ -203,6 +220,8 @@ public Object getCurrentKey() { */ @SuppressWarnings({"rawtypes", "unchecked"}) public S getPartitionedState(final N namespace, final TypeSerializer namespaceSerializer, final StateDescriptor stateDescriptor) throws Exception { + Preconditions.checkNotNull(namespace, "Namespace"); + Preconditions.checkNotNull(namespaceSerializer, "Namespace serializer"); if (keySerializer == null) { throw new RuntimeException("State key serializer has not been configured in the config. " + @@ -231,7 +250,7 @@ public S getPartitionedState(final N namespace, final TypeS } // create a new blank key/value state - S kvstate = stateDescriptor.bind(new StateBackend() { + S state = stateDescriptor.bind(new StateBackend() { @Override public ValueState createValueState(ValueStateDescriptor stateDesc) throws Exception { return AbstractStateBackend.this.createValueState(namespaceSerializer, stateDesc); @@ -254,16 +273,31 @@ public FoldingState createFoldingState(FoldingStateDescriptor) kvstate; + lastState = kvState; + + if (currentKey != null) { + kvState.setCurrentKey(currentKey); + } + + kvState.setCurrentNamespace(namespace); + + // Publish queryable state + if (stateDescriptor.isQueryable()) { + if (kvStateRegistry == null) { + throw new IllegalStateException("State backend has not been initialized for job."); + } - ((KvState) kvstate).setCurrentKey(currentKey); - ((KvState) kvstate).setCurrentNamespace(namespace); + String name = stateDescriptor.getQueryableStateName(); + kvStateRegistry.registerKvState(keyGroupIndex, name, kvState); + } - return kvstate; + return state; } @SuppressWarnings("unchecked,rawtypes") @@ -352,6 +386,16 @@ public void injectKeyValueStateSnapshots(HashMap keyVal keySerializer, userCodeClassLoader); keyValueStatesByName.put(state.getKey(), kvState); + + try { + // Publish queryable state + StateDescriptor stateDesc = kvState.getStateDescriptor(); + if (stateDesc.isQueryable()) { + String queryableStateName = stateDesc.getQueryableStateName(); + kvStateRegistry.registerKvState(keyGroupIndex, queryableStateName, kvState); + } + } catch (Throwable ignored) { + } } keyValueStates = keyValueStatesByName.values().toArray(new KvState[keyValueStatesByName.size()]); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java index 5f6600d1e1b26..4d75243f4b5d3 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericFoldingState.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.runtime.state; import org.apache.flink.api.common.functions.FoldFunction; @@ -68,6 +69,11 @@ public void setCurrentNamespace(N namespace) { wrappedState.setCurrentNamespace(namespace); } + @Override + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + return wrappedState.getSerializedValue(serializedKeyAndNamespace); + } + @Override public KvStateSnapshot, FoldingStateDescriptor, Backend> snapshot( long checkpointId, @@ -83,6 +89,11 @@ public void dispose() { wrappedState.dispose(); } + @Override + public FoldingStateDescriptor getStateDescriptor() { + throw new UnsupportedOperationException("Not supported by generic state type"); + } + @Override public ACC get() throws Exception { return wrappedState.value(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java index 2e408989c1e86..3d259e328e151 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericListState.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.runtime.state; import org.apache.flink.api.common.state.ListState; @@ -64,6 +65,11 @@ public void setCurrentNamespace(N namespace) { wrappedState.setCurrentNamespace(namespace); } + @Override + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + return wrappedState.getSerializedValue(serializedKeyAndNamespace); + } + @Override public KvStateSnapshot, ListStateDescriptor, Backend> snapshot( long checkpointId, @@ -79,6 +85,11 @@ public void dispose() { wrappedState.dispose(); } + @Override + public ListStateDescriptor getStateDescriptor() { + throw new UnsupportedOperationException("Not supported by generic state type"); + } + @Override public Iterable get() throws Exception { return wrappedState.value(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java index 9a2eb21f3433b..ccc946044553f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/GenericReducingState.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.runtime.state; import org.apache.flink.api.common.functions.ReduceFunction; @@ -67,6 +68,11 @@ public void setCurrentNamespace(N namespace) { wrappedState.setCurrentNamespace(namespace); } + @Override + public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception { + return wrappedState.getSerializedValue(serializedKeyAndNamespace); + } + @Override public KvStateSnapshot, ReducingStateDescriptor, Backend> snapshot( long checkpointId, @@ -82,6 +88,11 @@ public void dispose() { wrappedState.dispose(); } + @Override + public ReducingStateDescriptor getStateDescriptor() { + throw new UnsupportedOperationException("Not supported by generic state type"); + } + @Override public T get() throws Exception { return wrappedState.value(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvState.java index 89de000e9902b..a8aa872edebe3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvState.java @@ -51,6 +51,19 @@ public interface KvState */ void setCurrentNamespace(N namespace); + /** + * Returns the serialized value for the given key and namespace. + * + *
<p>
If no value is associated with key and namespace, null + * is returned. + * + * @param serializedKeyAndNamespace Serialized key and namespace + * @return Serialized value or null if no value is associated + * with the key and namespace. + * @throws Exception Exceptions during serialization are forwarded + */ + byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception; + /** * Creates a snapshot of this state. * @@ -67,4 +80,10 @@ public interface KvState * Disposes the key/value state, releasing all occupied resources. */ void dispose(); + + /** + * Returns the state descriptor from which the KvState instance was created. + */ + SD getStateDescriptor(); + } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespace.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespace.java new file mode 100644 index 0000000000000..9ff9df0e50fc2 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespace.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * Uninstantiable placeholder class for state without a namespace. + */ +public final class VoidNamespace { + + public static final VoidNamespace INSTANCE = new VoidNamespace(); + + private VoidNamespace() { + } + + public static VoidNamespace get() { + return INSTANCE; + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java new file mode 100644 index 0000000000000..8b58891e36ba6 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/VoidNamespaceSerializer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; + +/** + * Serializer for {@link VoidNamespace}. + */ +public final class VoidNamespaceSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + public static final VoidNamespaceSerializer INSTANCE = new VoidNamespaceSerializer(); + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public VoidNamespace createInstance() { + return VoidNamespace.get(); + } + + @Override + public VoidNamespace copy(VoidNamespace from) { + return VoidNamespace.get(); + } + + @Override + public VoidNamespace copy(VoidNamespace from, VoidNamespace reuse) { + return VoidNamespace.get(); + } + + @Override + public int getLength() { + return 0; + } + + @Override + public void serialize(VoidNamespace record, DataOutputView target) throws IOException { + // Make progress in the stream, write one byte. + // + // We could just skip writing anything here, because of the way this is + // used with the state backends, but if it is ever used somewhere else + // (even though it is unlikely to happen), it would be a problem. + target.write(0); + } + + @Override + public VoidNamespace deserialize(DataInputView source) throws IOException { + source.readByte(); + return VoidNamespace.get(); + } + + @Override + public VoidNamespace deserialize(VoidNamespace reuse, DataInputView source) throws IOException { + source.readByte(); + return VoidNamespace.get(); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.write(source.readByte()); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof VoidNamespaceSerializer; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsFoldingState.java index 90baf36cf805d..2fbbdc9e0eadd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsFoldingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsFoldingState.java @@ -23,8 +23,10 @@ import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.HashMap; @@ -71,7 +73,7 @@ public FsFoldingState(FsStateBackend backend, * @param keySerializer The serializer for the key. * @param namespaceSerializer The serializer for the namespace. * @param stateDesc The state identifier for the state. This contains name -* and can create a default state value. + * and can create a default state value. * @param state The map of key/value pairs to initialize the state with. 
*/ public FsFoldingState(FsStateBackend backend, @@ -86,20 +88,24 @@ public FsFoldingState(FsStateBackend backend, @Override public ACC get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } - return currentNSState != null ? - currentNSState.get(currentKey) : null; + if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); + return currentNSState.get(currentKey); + } else { + return null; + } } @Override public void add(T value) throws IOException { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } @@ -121,6 +127,19 @@ public KvStateSnapshot, FoldingStateDescriptor(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, filePath); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + + if (stateByKey != null) { + return KvStateRequestSerializer.serializeValue(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } public static class Snapshot extends AbstractFsStateSnapshot, FoldingStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsListState.java index 46c9830511ab6..dbef90093a611 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsListState.java @@ -18,13 +18,15 @@ package org.apache.flink.runtime.state.filesystem; -import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.ArrayListSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.util.ArrayList; import java.util.HashMap; @@ -81,24 +83,27 @@ public FsListState(FsStateBackend backend, @Override public Iterable get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } - return currentNSState != null ? 
- currentNSState.get(currentKey) : null; + if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); + return currentNSState.get(currentKey); + } else { + return null; + } } @Override public void add(V value) { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } - ArrayList list = currentNSState.get(currentKey); if (list == null) { list = new ArrayList<>(); @@ -112,6 +117,19 @@ public KvStateSnapshot, ListStateDescriptor, FsStateBacken return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), new ArrayListSerializer<>(stateDesc.getSerializer()), stateDesc, filePath); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map> stateByKey = state.get(namespace); + if (stateByKey != null) { + return KvStateRequestSerializer.serializeList(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } + public static class Snapshot extends AbstractFsStateSnapshot, ListState, ListStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsReducingState.java index ef721c973c01c..bb389d9e58db7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsReducingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsReducingState.java @@ -23,8 +23,10 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.HashMap; @@ -86,9 +88,11 @@ public FsReducingState(FsStateBackend backend, @Override public V get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); return currentNSState.get(currentKey); } return null; @@ -96,12 +100,11 @@ public V get() { @Override public void add(V value) throws IOException { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } // currentKeyState.merge(currentNamespace, value, new BiFunction() { @@ -130,6 +133,19 @@ public KvStateSnapshot, ReducingStateDescriptor, FsSta return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, filePath); } + @Override + public byte[] getSerializedValue(K key, N namespace) 
throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + if (stateByKey != null) { + return KvStateRequestSerializer.serializeValue(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } + public static class Snapshot extends AbstractFsStateSnapshot, ReducingStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index 8a8a26d5a2bb1..61cf7410b92c9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -31,9 +31,8 @@ import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.state.AbstractStateBackend; - +import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.slf4j.Logger; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsValueState.java index 40b973e03c3ed..698bc1fef5e04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsValueState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsValueState.java @@ -22,8 +22,10 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.util.HashMap; import java.util.Map; @@ -79,9 +81,11 @@ public FsValueState(FsStateBackend backend, @Override public V value() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); V value = currentNSState.get(currentKey); return value != null ? 
value : stateDesc.getDefaultValue(); } @@ -90,9 +94,7 @@ public V value() { @Override public void update(V value) { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (value == null) { clear(); @@ -100,7 +102,8 @@ public void update(V value) { } if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } @@ -112,6 +115,20 @@ public KvStateSnapshot, ValueStateDescriptor, FsStateBack return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, filePath); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + V value = stateByKey != null ? stateByKey.get(key) : stateDesc.getDefaultValue(); + if (value != null) { + return KvStateRequestSerializer.serializeValue(value, stateDesc.getSerializer()); + } else { + return KvStateRequestSerializer.serializeValue(stateDesc.getDefaultValue(), stateDesc.getSerializer()); + } + } + public static class Snapshot extends AbstractFsStateSnapshot, ValueStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemFoldingState.java index 9953a64563a5a..a4dec3bb1bde4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemFoldingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemFoldingState.java @@ -22,8 +22,10 @@ import org.apache.flink.api.common.state.FoldingState; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.HashMap; @@ -62,20 +64,24 @@ public MemFoldingState(TypeSerializer keySerializer, @Override public ACC get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } - return currentNSState != null ? 
- currentNSState.get(currentKey) : null; + if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); + return currentNSState.get(currentKey); + } else { + return null; + } } @Override public void add(T value) throws IOException { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } @@ -97,6 +103,20 @@ public KvStateSnapshot, FoldingStateDescriptor(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + + if (stateByKey != null) { + return KvStateRequestSerializer.serializeValue(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } + public static class Snapshot extends AbstractMemStateSnapshot, FoldingStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemListState.java index 97461d0140e4f..20b6eb57d4996 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemListState.java @@ -21,9 +21,11 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.ArrayListSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.util.ArrayList; import java.util.HashMap; @@ -52,24 +54,27 @@ public MemListState(TypeSerializer keySerializer, TypeSerializer namespace @Override public Iterable get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } - return currentNSState != null ? 
- currentNSState.get(currentKey) : null; + if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); + return currentNSState.get(currentKey); + } else { + return null; + } } @Override public void add(V value) { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } - ArrayList list = currentNSState.get(currentKey); if (list == null) { list = new ArrayList<>(); @@ -83,6 +88,19 @@ public KvStateSnapshot, ListStateDescriptor, MemoryStateBa return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map> stateByKey = state.get(namespace); + if (stateByKey != null) { + return KvStateRequestSerializer.serializeList(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } + public static class Snapshot extends AbstractMemStateSnapshot, ListState, ListStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemReducingState.java index ce1634436b277..9a4c676a05a1d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemReducingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemReducingState.java @@ -22,8 +22,10 @@ import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.HashMap; @@ -61,9 +63,11 @@ public MemReducingState(TypeSerializer keySerializer, @Override public V get() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); return currentNSState.get(currentKey); } return null; @@ -71,12 +75,11 @@ public V get() { @Override public void add(V value) throws IOException { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } // currentKeyState.merge(currentNamespace, value, new BiFunction() { @@ -106,6 +109,19 @@ public KvStateSnapshot, ReducingStateDescriptor, Memor return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + 
Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + if (stateByKey != null) { + return KvStateRequestSerializer.serializeValue(stateByKey.get(key), stateDesc.getSerializer()); + } else { + return null; + } + } + public static class Snapshot extends AbstractMemStateSnapshot, ReducingStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java index 45b41580497fc..c0e3779d140c7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemValueState.java @@ -21,8 +21,10 @@ import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.util.Preconditions; import java.util.HashMap; import java.util.Map; @@ -54,9 +56,11 @@ public MemValueState(TypeSerializer keySerializer, @Override public V value() { if (currentNSState == null) { + Preconditions.checkState(currentNamespace != null, "No namespace set"); currentNSState = state.get(currentNamespace); } if (currentNSState != null) { + Preconditions.checkState(currentKey != null, "No key set"); V value = currentNSState.get(currentKey); return value != null ? value : stateDesc.getDefaultValue(); } @@ -65,9 +69,7 @@ public V value() { @Override public void update(V value) { - if (currentKey == null) { - throw new RuntimeException("No key available."); - } + Preconditions.checkState(currentKey != null, "No key set"); if (value == null) { clear(); @@ -75,7 +77,8 @@ public void update(V value) { } if (currentNSState == null) { - currentNSState = new HashMap<>(); + Preconditions.checkState(currentNamespace != null, "No namespace set"); + currentNSState = createNewNamespaceMap(); state.put(currentNamespace, currentNSState); } @@ -87,6 +90,20 @@ public KvStateSnapshot, ValueStateDescriptor, MemoryState return new Snapshot<>(getKeySerializer(), getNamespaceSerializer(), stateSerializer, stateDesc, bytes); } + @Override + public byte[] getSerializedValue(K key, N namespace) throws Exception { + Preconditions.checkNotNull(key, "Key"); + Preconditions.checkNotNull(namespace, "Namespace"); + + Map stateByKey = state.get(namespace); + V value = stateByKey != null ? 
stateByKey.get(key) : stateDesc.getDefaultValue(); + if (value != null) { + return KvStateRequestSerializer.serializeValue(value, stateDesc.getSerializer()); + } else { + return KvStateRequestSerializer.serializeValue(stateDesc.getDefaultValue(), stateDesc.getSerializer()); + } + } + public static class Snapshot extends AbstractMemStateSnapshot, ValueStateDescriptor> { private static final long serialVersionUID = 1L; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java index a7926d50e5137..0f1c0f77bbe0e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java @@ -282,4 +282,10 @@ private static void validateBytesInStream(InputStream is, byte[] data) throws IO is.close(); } + @Test + public void testConcurrentMapIfQueryable() throws Exception { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + StateBackendTestBase.testConcurrentMapIfQueryable(backend); + } + } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index 73d919be6d609..d3b4dbce7d373 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.state; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Test; @@ -153,4 +155,10 @@ public void testOversizedStateStream() { fail(e.getMessage()); } } + + @Test + public void testConcurrentMapIfQueryable() throws Exception { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + StateBackendTestBase.testConcurrentMapIfQueryable(backend); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index 80f1de398a7c4..d59e17baa6559 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.state; import com.google.common.base.Joiner; - import org.apache.commons.io.output.ByteArrayOutputStream; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FoldFunction; @@ -37,18 +36,34 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateRegistryListener; +import 
org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.types.IntValue; - import org.junit.After; import org.junit.Before; import org.junit.Test; +import java.util.Collections; import java.util.HashMap; - -import static org.junit.Assert.*; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Generic tests for the partitioned state part of {@link AbstractStateBackend}. @@ -74,24 +89,33 @@ public void teardown() throws Exception { } @Test + @SuppressWarnings("unchecked") public void testValueState() throws Exception { - backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", String.class, null); kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - ValueState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + TypeSerializer keySerializer = IntSerializer.INSTANCE; + TypeSerializer namespaceSerializer = VoidNamespaceSerializer.INSTANCE; + TypeSerializer valueSerializer = kvId.getSerializer(); + + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState kvState = (KvState) state; // some modifications to the state backend.setCurrentKey(1); assertNull(state.value()); + assertNull(getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.update("1"); backend.setCurrentKey(2); assertNull(state.value()); + assertNull(getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.update("2"); backend.setCurrentKey(1); assertEquals("1", state.value()); + assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); // draw a snapshot HashMap> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2); @@ -122,10 +146,13 @@ public void testValueState() throws Exception { // validate the original state backend.setCurrentKey(1); assertEquals("u1", state.value()); + assertEquals("u1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("u2", state.value()); + assertEquals("u2", getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("u3", state.value()); + assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); @@ -136,12 +163,16 @@ public void testValueState() throws Exception { snapshot1.get(key).discardState(); } - ValueState restored1 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ValueState restored1 = 
backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState1 = (KvState) restored1; backend.setCurrentKey(1); assertEquals("1", restored1.value()); + assertEquals("1", getSerializedValue(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("2", restored1.value()); + assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); @@ -152,14 +183,19 @@ public void testValueState() throws Exception { snapshot2.get(key).discardState(); } - ValueState restored2 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ValueState restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState2 = (KvState) restored2; backend.setCurrentKey(1); assertEquals("u1", restored2.value()); + assertEquals("u1", getSerializedValue(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("u2", restored2.value()); + assertEquals("u2", getSerializedValue(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("u3", restored2.value()); + assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); } /** @@ -169,14 +205,14 @@ public void testValueState() throws Exception { * @throws Exception */ @Test + @SuppressWarnings("unchecked") public void testValueStateNullUpdate() throws Exception { - // precondition: LongSerializer must fail on null value. 
this way the test would fail // later if null values where actually stored in the state instead of acting as clear() try { LongSerializer.INSTANCE.serialize(null, new DataOutputViewStreamWrapper(new ByteArrayOutputStream())); - fail("Should faill with NullPointerException"); + fail("Should fail with NullPointerException"); } catch (NullPointerException e) { // alrighty } @@ -186,7 +222,7 @@ public void testValueStateNullUpdate() throws Exception { ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L); kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - ValueState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); // some modifications to the state backend.setCurrentKey(1); @@ -218,7 +254,6 @@ public void testValueStateNullUpdate() throws Exception { } } - backend.dispose(); backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); @@ -228,7 +263,7 @@ public void testValueStateNullUpdate() throws Exception { snapshot1.get(key).discardState(); } - backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); } @Test @@ -238,18 +273,29 @@ public void testListState() { backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); ListStateDescriptor kvId = new ListStateDescriptor<>("id", String.class); - ListState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); + + TypeSerializer keySerializer = IntSerializer.INSTANCE; + TypeSerializer namespaceSerializer = VoidNamespaceSerializer.INSTANCE; + TypeSerializer valueSerializer = kvId.getSerializer(); + + ListState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState kvState = (KvState) state; Joiner joiner = Joiner.on(","); // some modifications to the state backend.setCurrentKey(1); assertEquals(null, state.get()); + assertEquals(null, getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.add("1"); backend.setCurrentKey(2); assertEquals(null, state.get()); + assertEquals(null, getSerializedList(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.add("2"); backend.setCurrentKey(1); assertEquals("1", joiner.join(state.get())); + assertEquals("1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); // draw a snapshot HashMap> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2); @@ -280,10 +326,13 @@ public void testListState() { // validate the original state backend.setCurrentKey(1); assertEquals("1,u1", joiner.join(state.get())); + assertEquals("1,u1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.setCurrentKey(2); assertEquals("2,u2", joiner.join(state.get())); + assertEquals("2,u2", joiner.join(getSerializedList(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.setCurrentKey(3); assertEquals("u3", joiner.join(state.get())); + assertEquals("u3", joiner.join(getSerializedList(kvState, 3, keySerializer, 
VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.dispose(); @@ -295,12 +344,16 @@ public void testListState() { snapshot1.get(key).discardState(); } - ListState restored1 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ListState restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState1 = (KvState) restored1; backend.setCurrentKey(1); assertEquals("1", joiner.join(restored1.get())); + assertEquals("1", joiner.join(getSerializedList(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.setCurrentKey(2); assertEquals("2", joiner.join(restored1.get())); + assertEquals("2", joiner.join(getSerializedList(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.dispose(); @@ -312,14 +365,19 @@ public void testListState() { snapshot2.get(key).discardState(); } - ListState restored2 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ListState restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState2 = (KvState) restored2; backend.setCurrentKey(1); assertEquals("1,u1", joiner.join(restored2.get())); + assertEquals("1,u1", joiner.join(getSerializedList(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.setCurrentKey(2); assertEquals("2,u2", joiner.join(restored2.get())); + assertEquals("2,u2", joiner.join(getSerializedList(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); backend.setCurrentKey(3); assertEquals("u3", joiner.join(restored2.get())); + assertEquals("u3", joiner.join(getSerializedList(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer))); } catch (Exception e) { e.printStackTrace(); @@ -328,23 +386,34 @@ public void testListState() { } @Test + @SuppressWarnings("unchecked") public void testReducingState() { try { backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); ReducingStateDescriptor kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), String.class); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - ReducingState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + TypeSerializer keySerializer = IntSerializer.INSTANCE; + TypeSerializer namespaceSerializer = VoidNamespaceSerializer.INSTANCE; + TypeSerializer valueSerializer = kvId.getSerializer(); + + ReducingState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState kvState = (KvState) state; // some modifications to the state backend.setCurrentKey(1); assertEquals(null, state.get()); + assertNull(getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.add("1"); backend.setCurrentKey(2); assertEquals(null, state.get()); + assertNull(getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.add("2"); backend.setCurrentKey(1); assertEquals("1", state.get()); + assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); // draw a snapshot 
HashMap> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2); @@ -375,10 +444,13 @@ public void testReducingState() { // validate the original state backend.setCurrentKey(1); assertEquals("1,u1", state.get()); + assertEquals("1,u1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("2,u2", state.get()); + assertEquals("2,u2", getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("u3", state.get()); + assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); @@ -390,12 +462,16 @@ public void testReducingState() { snapshot1.get(key).discardState(); } - ReducingState restored1 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ReducingState restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState1 = (KvState) restored1; backend.setCurrentKey(1); assertEquals("1", restored1.get()); + assertEquals("1", getSerializedValue(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("2", restored1.get()); + assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); @@ -407,15 +483,19 @@ public void testReducingState() { snapshot2.get(key).discardState(); } - ReducingState restored2 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); - + ReducingState restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState2 = (KvState) restored2; backend.setCurrentKey(1); assertEquals("1,u1", restored2.get()); + assertEquals("1,u1", getSerializedValue(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("2,u2", restored2.get()); + assertEquals("2,u2", getSerializedValue(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("u3", restored2.get()); + assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); } catch (Exception e) { e.printStackTrace(); @@ -432,19 +512,29 @@ public void testFoldingState() { FoldingStateDescriptor kvId = new FoldingStateDescriptor<>("id", "Fold-Initial:", new AppendingFold(), - String.class); + String.class); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - FoldingState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + TypeSerializer keySerializer = IntSerializer.INSTANCE; + TypeSerializer namespaceSerializer = VoidNamespaceSerializer.INSTANCE; + TypeSerializer valueSerializer = kvId.getSerializer(); + + FoldingState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState kvState = (KvState) state; // some modifications to the state backend.setCurrentKey(1); assertEquals(null, state.get()); + assertEquals(null, getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, 
valueSerializer)); state.add(1); backend.setCurrentKey(2); assertEquals(null, state.get()); + assertEquals(null, getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); state.add(2); backend.setCurrentKey(1); assertEquals("Fold-Initial:,1", state.get()); + assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); // draw a snapshot HashMap> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2); @@ -476,10 +566,13 @@ public void testFoldingState() { // validate the original state backend.setCurrentKey(1); assertEquals("Fold-Initial:,101", state.get()); + assertEquals("Fold-Initial:,101", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("Fold-Initial:,2,102", state.get()); + assertEquals("Fold-Initial:,2,102", getSerializedValue(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("Fold-Initial:,103", state.get()); + assertEquals("Fold-Initial:,103", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); @@ -491,12 +584,16 @@ public void testFoldingState() { snapshot1.get(key).discardState(); } - FoldingState restored1 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + FoldingState restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState1 = (KvState) restored1; backend.setCurrentKey(1); assertEquals("Fold-Initial:,1", restored1.get()); + assertEquals("Fold-Initial:,1", getSerializedValue(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("Fold-Initial:,2", restored1.get()); + assertEquals("Fold-Initial:,2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.dispose(); @@ -509,14 +606,19 @@ public void testFoldingState() { } @SuppressWarnings("unchecked") - FoldingState restored2 = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + FoldingState restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + @SuppressWarnings("unchecked") + KvState restoredKvState2 = (KvState) restored2; backend.setCurrentKey(1); assertEquals("Fold-Initial:,101", restored2.get()); + assertEquals("Fold-Initial:,101", getSerializedValue(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(2); assertEquals("Fold-Initial:,2,102", restored2.get()); + assertEquals("Fold-Initial:,2,102", getSerializedValue(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); backend.setCurrentKey(3); assertEquals("Fold-Initial:,103", restored2.get()); + assertEquals("Fold-Initial:,103", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)); } catch (Exception e) { e.printStackTrace(); @@ -525,6 +627,7 @@ public void testFoldingState() { } @Test + @SuppressWarnings("unchecked") public void testValueStateRestoreWithWrongSerializers() { try { backend.initializeForJob(new DummyEnvironment("test", 1, 0), @@ -534,7 +637,7 @@ 
public void testValueStateRestoreWithWrongSerializers() { ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", String.class, null); kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - ValueState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); backend.setCurrentKey(1); state.update("1"); @@ -567,7 +670,7 @@ public void testValueStateRestoreWithWrongSerializers() { try { kvId = new ValueStateDescriptor<>("id", fakeStringSerializer, null); - state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); state.value(); @@ -588,12 +691,13 @@ public void testValueStateRestoreWithWrongSerializers() { } @Test + @SuppressWarnings("unchecked") public void testListStateRestoreWithWrongSerializers() { try { backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); ListStateDescriptor kvId = new ListStateDescriptor<>("id", String.class); - ListState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ListState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); backend.setCurrentKey(1); state.add("1"); @@ -626,7 +730,7 @@ public void testListStateRestoreWithWrongSerializers() { try { kvId = new ListStateDescriptor<>("id", fakeStringSerializer); - state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); state.get(); @@ -647,6 +751,7 @@ public void testListStateRestoreWithWrongSerializers() { } @Test + @SuppressWarnings("unchecked") public void testReducingStateRestoreWithWrongSerializers() { try { backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); @@ -654,14 +759,13 @@ public void testReducingStateRestoreWithWrongSerializers() { ReducingStateDescriptor kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), StringSerializer.INSTANCE); - ReducingState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ReducingState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); backend.setCurrentKey(1); state.add("1"); backend.setCurrentKey(2); state.add("2"); - // draw a snapshot HashMap> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2); @@ -688,7 +792,7 @@ public void testReducingStateRestoreWithWrongSerializers() { try { kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), fakeStringSerializer); - state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); state.get(); @@ -715,7 +819,7 @@ public void testCopyDefaultValue() throws Exception { ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1)); kvId.initializeSerializerUnlessSet(new ExecutionConfig()); - ValueState state = backend.getPartitionedState(null, VoidSerializer.INSTANCE, kvId); + ValueState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); backend.setCurrentKey(1); IntValue default1 = state.value(); @@ -729,6 +833,222 @@ public void testCopyDefaultValue() throws Exception { assertFalse(default1 == 
default2); } + /** + * Previously, it was possible to create partitioned state with + * null namespace. This test makes sure that this is + * prohibited now. + */ + @Test + public void testRequireNonNullNamespace() throws Exception { + backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE); + + ValueStateDescriptor kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1)); + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); + + try { + backend.getPartitionedState(null, VoidNamespaceSerializer.INSTANCE, kvId); + fail("Did not throw expected NullPointerException"); + } catch (NullPointerException ignored) { + } + + try { + backend.getPartitionedState(VoidNamespace.INSTANCE, null, kvId); + fail("Did not throw expected NullPointerException"); + } catch (NullPointerException ignored) { + } + + try { + backend.getPartitionedState(null, null, kvId); + fail("Did not throw expected NullPointerException"); + } catch (NullPointerException ignored) { + } + } + + /** + * Tests that {@link AbstractHeapState} instances respect the queryable + * flag and create concurrent variants for internal state structures. + */ + @SuppressWarnings("unchecked") + protected static void testConcurrentMapIfQueryable(B backend) throws Exception { + { + // ValueState + ValueStateDescriptor desc = new ValueStateDescriptor<>( + "value-state", + Integer.class, + -1); + desc.setQueryable("my-query"); + desc.initializeSerializerUnlessSet(new ExecutionConfig()); + + ValueState state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvState kvState = (KvState) state; + assertTrue(kvState instanceof AbstractHeapState); + + Map> stateMap = ((AbstractHeapState) kvState).getStateMap(); + assertTrue(stateMap instanceof ConcurrentHashMap); + + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.setCurrentKey(1); + state.update(121818273); + + Map namespaceMap = stateMap.get(VoidNamespace.INSTANCE); + + assertNotNull("Value not set", namespaceMap); + assertTrue(namespaceMap instanceof ConcurrentHashMap); + } + + { + // ListState + ListStateDescriptor desc = new ListStateDescriptor<>("list-state", Integer.class); + desc.setQueryable("my-query"); + desc.initializeSerializerUnlessSet(new ExecutionConfig()); + + ListState state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvState kvState = (KvState) state; + assertTrue(kvState instanceof AbstractHeapState); + + Map> stateMap = ((AbstractHeapState) kvState).getStateMap(); + assertTrue(stateMap instanceof ConcurrentHashMap); + + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.setCurrentKey(1); + state.add(121818273); + + Map namespaceMap = stateMap.get(VoidNamespace.INSTANCE); + + assertNotNull("List not set", namespaceMap); + assertTrue(namespaceMap instanceof ConcurrentHashMap); + } + + { + // ReducingState + ReducingStateDescriptor desc = new ReducingStateDescriptor<>( + "reducing-state", new ReduceFunction() { + @Override + public Integer reduce(Integer value1, Integer value2) throws Exception { + return value1 + value2; + } + }, Integer.class); + desc.setQueryable("my-query"); + desc.initializeSerializerUnlessSet(new ExecutionConfig()); + + ReducingState state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvState kvState = (KvState) state; + assertTrue(kvState instanceof AbstractHeapState); + + Map> stateMap = 
((AbstractHeapState) kvState).getStateMap(); + assertTrue(stateMap instanceof ConcurrentHashMap); + + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.setCurrentKey(1); + state.add(121818273); + + Map namespaceMap = stateMap.get(VoidNamespace.INSTANCE); + + assertNotNull("List not set", namespaceMap); + assertTrue(namespaceMap instanceof ConcurrentHashMap); + } + + { + // FoldingState + FoldingStateDescriptor desc = new FoldingStateDescriptor<>( + "folding-state", 0, new FoldFunction() { + @Override + public Integer fold(Integer accumulator, Integer value) throws Exception { + return accumulator + value; + } + }, Integer.class); + desc.setQueryable("my-query"); + desc.initializeSerializerUnlessSet(new ExecutionConfig()); + + FoldingState state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvState kvState = (KvState) state; + assertTrue(kvState instanceof AbstractHeapState); + + Map> stateMap = ((AbstractHeapState) kvState).getStateMap(); + assertTrue(stateMap instanceof ConcurrentHashMap); + + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.setCurrentKey(1); + state.add(121818273); + + Map namespaceMap = stateMap.get(VoidNamespace.INSTANCE); + + assertNotNull("List not set", namespaceMap); + assertTrue(namespaceMap instanceof ConcurrentHashMap); + } + } + + /** + * Tests registration with the KvStateRegistry. + */ + @Test + public void testQueryableStateRegistration() throws Exception { + DummyEnvironment env = new DummyEnvironment("test", 1, 0); + KvStateRegistry registry = env.getKvStateRegistry(); + + KvStateRegistryListener listener = mock(KvStateRegistryListener.class); + registry.registerListener(listener); + + backend.initializeForJob(env, "test_op", IntSerializer.INSTANCE); + + ValueStateDescriptor desc = new ValueStateDescriptor<>( + "test", + IntSerializer.INSTANCE, + null); + desc.setQueryable("banana"); + + backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc); + + // Verify registered + verify(listener, times(1)).notifyKvStateRegistered( + eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana"), any(KvStateID.class)); + + + HashMap> snapshot = backend + .snapshotPartitionedState(682375462379L, 4); + + for (String key: snapshot.keySet()) { + if (snapshot.get(key) instanceof AsynchronousKvStateSnapshot) { + snapshot.put(key, ((AsynchronousKvStateSnapshot) snapshot.get(key)).materialize()); + } + } + + // Verify unregistered + backend.dispose(); + + verify(listener, times(1)).notifyKvStateUnregistered( + eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana")); + + // Initialize again + backend.initializeForJob(env, "test_op", IntSerializer.INSTANCE); + + backend.injectKeyValueStateSnapshots((HashMap) snapshot); + + backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc); + + // Verify registered again + verify(listener, times(2)).notifyKvStateRegistered( + eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana"), any(KvStateID.class)); + + + } + private static class AppendingReduce implements ReduceFunction { @Override public String reduce(String value1, String value2) throws Exception { @@ -744,4 +1064,52 @@ public String fold(String acc, Integer value) throws Exception { return acc + "," + value; } } + + /** + * Returns the value by getting the serialized value and deserializing it + * if it is not null. 
+ */ + private static V getSerializedValue( + KvState kvState, + K key, + TypeSerializer keySerializer, + N namespace, + TypeSerializer namespaceSerializer, + TypeSerializer valueSerializer) throws Exception { + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, keySerializer, namespace, namespaceSerializer); + + byte[] serializedValue = kvState.getSerializedValue(serializedKeyAndNamespace); + + if (serializedValue == null) { + return null; + } else { + return KvStateRequestSerializer.deserializeValue(serializedValue, valueSerializer); + } + } + + /** + * Returns the value by getting the serialized value and deserializing it + * if it is not null. + */ + private static List getSerializedList( + KvState kvState, + K key, + TypeSerializer keySerializer, + N namespace, + TypeSerializer namespaceSerializer, + TypeSerializer valueSerializer) throws Exception { + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, keySerializer, namespace, namespaceSerializer); + + byte[] serializedValue = kvState.getSerializedValue(serializedKeyAndNamespace); + + if (serializedValue == null) { + return null; + } else { + return KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer); + } + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 0269a342f750e..15bb384442319 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -23,19 +23,19 @@ import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.runtime.state.KvStateSnapshot; import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.StreamTaskState; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -189,7 +189,6 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) } } - return state; } @@ -271,7 +270,7 @@ protected long getCurrentProcessingTime() { * @throws Exception Thrown, if the state backend cannot create the key/value state. 
*/ protected S getPartitionedState(StateDescriptor stateDescriptor) throws Exception { - return getStateBackend().getPartitionedState(null, VoidSerializer.INSTANCE, stateDescriptor); + return getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); } /** diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index 2434843bd04a2..98bb30350364f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -33,7 +33,6 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; @@ -45,6 +44,8 @@ import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; @@ -555,7 +556,7 @@ protected MergingWindowSet getMergingWindowSet() throws Exception { TupleSerializer> tupleSerializer = new TupleSerializer<>((Class) Tuple2.class, new TypeSerializer[] {windowSerializer, windowSerializer} ); ListStateDescriptor> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer); - ListState> mergeState = getStateBackend().getPartitionedState(null, VoidSerializer.INSTANCE, mergeStateDescriptor); + ListState> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); mergingWindows = new MergingWindowSet<>((MergingWindowAssigner) windowAssigner, mergeState); mergeState.clear(); @@ -863,7 +864,7 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) ListStateDescriptor> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer); for (Map.Entry> key: mergingWindowsByKey.entrySet()) { setKeyContext(key.getKey()); - ListState> mergeState = getStateBackend().getPartitionedState(null, VoidSerializer.INSTANCE, mergeStateDescriptor); + ListState> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor); mergeState.clear(); key.getValue().persist(mergeState); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index 30ebb20f4b8e1..b2f6dbdab1924 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -29,14 +29,13 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.execution.Environment; - +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.memory.MemListState; import org.junit.Test; - import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -44,8 +43,12 @@ import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicReference; -import static org.mockito.Mockito.*; -import static org.junit.Assert.*; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class StreamingRuntimeContextTest { @@ -179,8 +182,10 @@ private static AbstractStreamOperator createPlainMockOp() throws Exception { public ListState answer(InvocationOnMock invocationOnMock) throws Throwable { ListStateDescriptor descr = (ListStateDescriptor) invocationOnMock.getArguments()[0]; - return new MemListState( - StringSerializer.INSTANCE, VoidSerializer.INSTANCE, descr); + MemListState listState = new MemListState<>( + StringSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr); + listState.setCurrentNamespace(VoidNamespace.INSTANCE); + return listState; } }); From fdccd5ee4ea460518bd1bcf072ac40faeeccf769 Mon Sep 17 00:00:00 2001 From: Ufuk Celebi Date: Mon, 30 May 2016 14:03:35 +0200 Subject: [PATCH 2/6] [FLINK-3779] [runtime] Add KvStateRegistry for queryable KvState [streaming-java] - Adds a KvStateRegistry per TaskManager at which created KvState instances are registered/unregistered. - Registered KvState instances are reported to the JobManager, which can be queried for KvStateLocation.
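
As a rough usage sketch of the registration mechanism described above (illustrative only, not part of the patch): the LoggingKvStateRegistryListener class, its log output, and the main() wiring are hypothetical; only the listener method signatures and the registerListener()/createTaskRegistry() calls mirror the code added below in NetworkEnvironment and KvStateRegistry.

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateRegistryListener;
import org.apache.flink.runtime.query.TaskKvStateRegistry;

/**
 * Hypothetical listener that simply logs KvState (un)registrations reported
 * by the TaskManager-side KvStateRegistry. The production listener in
 * NetworkEnvironment forwards the same arguments to the JobManager as
 * KvStateMessage notifications instead of logging them.
 */
public class LoggingKvStateRegistryListener implements KvStateRegistryListener {

	@Override
	public void notifyKvStateRegistered(
			JobID jobId,
			JobVertexID jobVertexId,
			int keyGroupIndex,
			String registrationName,
			KvStateID kvStateId) {
		// Called when a backend registers state that was flagged as queryable.
		System.out.printf("Registered KvState '%s' (key group %d) of %s as %s%n",
				registrationName, keyGroupIndex, jobVertexId, kvStateId);
	}

	@Override
	public void notifyKvStateUnregistered(
			JobID jobId,
			JobVertexID jobVertexId,
			int keyGroupIndex,
			String registrationName) {
		// Called when the owning task disposes its state backend.
		System.out.printf("Unregistered KvState '%s' (key group %d) of %s%n",
				registrationName, keyGroupIndex, jobVertexId);
	}

	// Usage sketch, assuming a TaskManager-side setup similar to NetworkEnvironment.
	public static void main(String[] args) {
		KvStateRegistry registry = new KvStateRegistry();
		registry.registerListener(new LoggingKvStateRegistryListener());

		// Each task obtains a per-job/per-vertex view through which its state
		// backend registers queryable state under the configured name.
		TaskKvStateRegistry taskRegistry =
				registry.createTaskRegistry(new JobID(), new JobVertexID());
	}
}
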
--- .../state/RocksDBStateBackendConfigTest.java | 6 + .../flink/runtime/execution/Environment.java | 9 + .../executiongraph/ExecutionGraph.java | 10 + .../io/network/NetworkEnvironment.java | 125 +++++++- .../flink/runtime/query/KvStateLocation.java | 239 +++++++++++++++ .../query/KvStateLocationRegistry.java | 161 ++++++++++ .../flink/runtime/query/KvStateMessage.java | 290 ++++++++++++++++++ .../flink/runtime/query/KvStateRegistry.java | 157 ++++++++++ .../query/KvStateRegistryListener.java | 62 ++++ .../runtime/query/TaskKvStateRegistry.java | 93 ++++++ .../taskmanager/RuntimeEnvironment.java | 10 + .../flink/runtime/taskmanager/Task.java | 5 + .../flink/runtime/jobmanager/JobManager.scala | 76 ++++- .../runtime/taskmanager/TaskManager.scala | 6 +- .../io/network/NetworkEnvironmentTest.java | 7 +- .../runtime/jobmanager/JobManagerTest.java | 258 +++++++++++++++- .../operators/testutils/DummyEnvironment.java | 15 + .../testutils/DummyEnvironment.java.orig | 185 +++++++++++ .../operators/testutils/MockEnvironment.java | 14 +- .../query/KvStateLocationRegistryTest.java | 231 ++++++++++++++ .../runtime/query/KvStateLocationTest.java | 92 ++++++ .../taskmanager/TaskAsyncCallTest.java | 3 + ...kManagerComponentsStartupShutdownTest.java | 3 +- .../flink/runtime/taskmanager/TaskTest.java | 10 +- .../tasks/InterruptSensitiveRestoreTest.java | 7 +- .../runtime/tasks/StreamMockEnvironment.java | 12 + .../runtime/tasks/StreamTaskTest.java | 15 +- 27 files changed, 2062 insertions(+), 39 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocation.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationRegistry.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistry.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistryListener.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java.orig create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationRegistryTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationTest.java diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java index 0878b8c5998f8..657c57e53514a 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java @@ -20,6 +20,7 @@ import org.apache.commons.io.FileUtils; import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; @@ -311,6 +312,11 @@ private static Environment getMockEnvironment(File[] tempDirs) { when(env.getJobID()).thenReturn(new JobID()); 
when(env.getUserClassLoader()).thenReturn(RocksDBStateBackendConfigTest.class.getClassLoader()); when(env.getIOManager()).thenReturn(ioMan); + + TaskInfo taskInfo = mock(TaskInfo.class); + when(env.getTaskInfo()).thenReturn(taskInfo); + + when(taskInfo.getIndexOfThisSubtask()).thenReturn(0); return env; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 5ad5fe2831c5d..2f158fd3e0436 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -33,6 +33,8 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -147,6 +149,13 @@ public interface Environment { */ AccumulatorRegistry getAccumulatorRegistry(); + /** + * Returns the registry for {@link KvState} instances. + * + * @return KvState registry + */ + TaskKvStateRegistry getTaskKvStateRegistry(); + /** * Confirms that the invokable has successfully completed all steps it needed to * to for the checkpoint with the give checkpoint-ID. This method does not include diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 1a0301dab3a72..e6ae6ce724c81 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -54,6 +54,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.Scheduler; import org.apache.flink.runtime.messages.ExecutionGraphMessages; +import org.apache.flink.runtime.query.KvStateLocationRegistry; import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.util.SerializableObject; import org.apache.flink.runtime.util.SerializedThrowable; @@ -224,6 +225,9 @@ public class ExecutionGraph { /** The execution context which is used to execute futures. */ private ExecutionContext executionContext; + /** Registered KvState instances reported by the TaskManagers. 
*/ + private transient KvStateLocationRegistry kvStateLocationRegistry; + // ------ Fields that are only relevant for archived execution graphs ------------ private String jsonPlan; @@ -304,6 +308,8 @@ public ExecutionGraph( this.restartStrategy = restartStrategy; metricGroup.gauge(RESTARTING_TIME_METRIC_NAME, new RestartTimeGauge()); + + this.kvStateLocationRegistry = new KvStateLocationRegistry(jobId, getAllVertices()); } // -------------------------------------------------------------------------------------------- @@ -445,6 +451,10 @@ public SavepointCoordinator getSavepointCoordinator() { return savepointCoordinator; } + public KvStateLocationRegistry getKvStateLocationRegistry() { + return kvStateLocationRegistry; + } + public RestartStrategy getRestartStrategy() { return restartStrategy; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 30d2e387d1361..283d804509893 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.instance.ActorGateway; +import org.apache.flink.runtime.instance.InstanceConnectionInfo; import org.apache.flink.runtime.io.disk.iomanager.IOManager.IOMode; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.BufferPool; @@ -35,11 +36,21 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.JobManagerMessages.RequestPartitionState; import org.apache.flink.runtime.messages.TaskMessages.FailTask; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateMessage; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateRegistryListener; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManager; +import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats; +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Option; @@ -84,6 +95,12 @@ public class NetworkEnvironment { private PartitionStateChecker partitionStateChecker; + /** Server for {@link org.apache.flink.runtime.state.KvState} requests. */ + private KvStateServer kvStateServer; + + /** Registry for {@link org.apache.flink.runtime.state.KvState} instances. */ + private KvStateRegistry kvStateRegistry; + private boolean isShutdown; /** @@ -92,17 +109,21 @@ public class NetworkEnvironment { */ private final ExecutionContext executionContext; + private final InstanceConnectionInfo connectionInfo; + /** * Initializes all network I/O components. 
*/ public NetworkEnvironment( - ExecutionContext executionContext, - FiniteDuration jobManagerTimeout, - NetworkEnvironmentConfiguration config) throws IOException { + ExecutionContext executionContext, + FiniteDuration jobManagerTimeout, + NetworkEnvironmentConfiguration config, + InstanceConnectionInfo connectionInfo) throws IOException { this.executionContext = executionContext; this.configuration = checkNotNull(config); this.jobManagerTimeout = checkNotNull(jobManagerTimeout); + this.connectionInfo = checkNotNull(connectionInfo); // create the network buffers - this is the operation most likely to fail upon // mis-configuration, so we do this first @@ -151,6 +172,10 @@ public Tuple2 getPartitionRequestInitialAndMaxBackoff() { return configuration.partitionRequestInitialAndMaxBackoff(); } + public TaskKvStateRegistry createKvStateTaskRegistry(JobID jobId, JobVertexID jobVertexId) { + return kvStateRegistry.createTaskRegistry(jobId, jobVertexId); + } + // -------------------------------------------------------------------------------------------- // Association / Disassociation with JobManager / TaskManager // -------------------------------------------------------------------------------------------- @@ -183,7 +208,9 @@ public void associateWithTaskManagerAndJobManager( if (this.partitionConsumableNotifier == null && this.partitionManager == null && this.taskEventDispatcher == null && - this.connectionManager == null) + this.connectionManager == null && + this.kvStateRegistry == null && + this.kvStateServer == null) { // good, not currently associated. start the individual components @@ -211,6 +238,29 @@ public void associateWithTaskManagerAndJobManager( catch (Throwable t) { throw new IOException("Failed to instantiate network connection manager: " + t.getMessage(), t); } + + try { + kvStateRegistry = new KvStateRegistry(); + + kvStateServer = new KvStateServer( + connectionInfo.address(), + 0, + 1, + 10, + kvStateRegistry, + new AtomicKvStateRequestStats()); + + kvStateServer.start(); + + KvStateRegistryListener listener = new JobManagerKvStateRegistryListener( + jobManagerGateway, + kvStateServer.getAddress()); + + kvStateRegistry.registerListener(listener); + } catch (Throwable t) { + throw new IOException("Failed to instantiate KvState management components: " + + t.getMessage(), t); + } } else { throw new IllegalStateException( @@ -227,6 +277,19 @@ public void disassociate() throws IOException { LOG.debug("Disassociating NetworkEnvironment from TaskManager. Cleaning all intermediate results."); + // Shut down KvStateRegistry + kvStateRegistry = null; + + // Shut down KvStateServer + if (kvStateServer != null) { + try { + kvStateServer.shutDown(); + } catch (Throwable t) { + throw new IOException("Cannot shutdown KvStateNettyServer", t); + } + kvStateServer = null; + } + // terminate all network connections if (connectionManager != null) { try { @@ -511,4 +574,58 @@ public void triggerPartitionStateCheck( jobManager.tell(msg, taskManager); } } + + /** + * Simple {@link KvStateRegistry} listener, which forwards registrations to + * the JobManager. 
+ */ + private static class JobManagerKvStateRegistryListener implements KvStateRegistryListener { + + private ActorGateway jobManager; + + private KvStateServerAddress kvStateServerAddress; + + public JobManagerKvStateRegistryListener( + ActorGateway jobManager, + KvStateServerAddress kvStateServerAddress) { + + this.jobManager = Preconditions.checkNotNull(jobManager, "JobManager"); + this.kvStateServerAddress = Preconditions.checkNotNull(kvStateServerAddress, "KvStateServerAddress"); + } + + @Override + public void notifyKvStateRegistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvStateID kvStateId) { + + Object msg = new KvStateMessage.NotifyKvStateRegistered( + jobId, + jobVertexId, + keyGroupIndex, + registrationName, + kvStateId, + kvStateServerAddress); + + jobManager.tell(msg); + } + + @Override + public void notifyKvStateUnregistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName) { + + Object msg = new KvStateMessage.NotifyKvStateUnregistered( + jobId, + jobVertexId, + keyGroupIndex, + registrationName); + + jobManager.tell(msg); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocation.java new file mode 100644 index 0000000000000..9be22c2458daf --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocation.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * Location information for all key groups of a {@link KvState} instance. + * + *
<p>
This is populated by the {@link KvStateLocationRegistry} and used by the + * {@link QueryableStateClient} to target queries. + */ +public class KvStateLocation implements Serializable { + + private static final long serialVersionUID = 1L; + + /** JobID the KvState instances belong to. */ + private final JobID jobId; + + /** JobVertexID the KvState instances belong to. */ + private final JobVertexID jobVertexId; + + /** Number of key groups of the operator the KvState instances belong to. */ + private final int numKeyGroups; + + /** Name under which the KvState instances have been registered. */ + private final String registrationName; + + /** IDs for each KvState instance where array index corresponds to key group index. */ + private final KvStateID[] kvStateIds; + + /** + * Server address for each KvState instance where array index corresponds to + * key group index. + */ + private final KvStateServerAddress[] kvStateAddresses; + + /** Current number of registered key groups. */ + private int numRegisteredKeyGroups; + + /** + * Creates the location information + * + * @param jobId JobID the KvState instances belong to + * @param jobVertexId JobVertexID the KvState instances belong to + * @param numKeyGroups Number of key groups of the operator + * @param registrationName Name under which the KvState instances have been registered + */ + public KvStateLocation(JobID jobId, JobVertexID jobVertexId, int numKeyGroups, String registrationName) { + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); + Preconditions.checkArgument(numKeyGroups >= 0, "Negative number of key groups"); + this.numKeyGroups = numKeyGroups; + this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); + this.kvStateIds = new KvStateID[numKeyGroups]; + this.kvStateAddresses = new KvStateServerAddress[numKeyGroups]; + } + + /** + * Returns the JobID the KvState instances belong to. + * + * @return JobID the KvState instances belong to + */ + public JobID getJobId() { + return jobId; + } + + /** + * Returns the JobVertexID the KvState instances belong to. + * + * @return JobVertexID the KvState instances belong to + */ + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + /** + * Returns the number of key groups of the operator the KvState instances belong to. + * + * @return Number of key groups of the operator the KvState instances belong to + */ + public int getNumKeyGroups() { + return numKeyGroups; + } + + /** + * Returns the name under which the KvState instances have been registered. + * + * @return Name under which the KvState instances have been registered. + */ + public String getRegistrationName() { + return registrationName; + } + + /** + * Returns the current number of registered key groups. + * + * @return Number of registered key groups. + */ + public int getNumRegisteredKeyGroups() { + return numRegisteredKeyGroups; + } + + /** + * Returns the registered KvStateID for the key group index or + * null if none is registered yet. + * + * @param keyGroupIndex Key group index to get ID for. 
+ * @return KvStateID for the key group index or null if none + * is registered yet + * @throws IndexOutOfBoundsException If key group index < 0 or >= Number of key groups + */ + public KvStateID getKvStateID(int keyGroupIndex) { + if (keyGroupIndex < 0 || keyGroupIndex >= numKeyGroups) { + throw new IndexOutOfBoundsException("Key group index"); + } + + return kvStateIds[keyGroupIndex]; + } + + /** + * Returns the registered KvStateServerAddress for the key group index or + * null if none is registered yet. + * + * @param keyGroupIndex Key group index to get server address for. + * @return KvStateServerAddress for the key group index or null + * if none is registered yet + * @throws IndexOutOfBoundsException If key group index < 0 or >= Number of key groups + */ + public KvStateServerAddress getKvStateServerAddress(int keyGroupIndex) { + if (keyGroupIndex < 0 || keyGroupIndex >= numKeyGroups) { + throw new IndexOutOfBoundsException("Key group index"); + } + + return kvStateAddresses[keyGroupIndex]; + } + + /** + * Registers a KvState instance for the given key group index. + * + * @param keyGroupIndex Key group index to register + * @param kvStateId ID of the KvState instance at the key group index. + * @param kvStateAddress Server address of the KvState instance at the key group index. + * @throws IndexOutOfBoundsException If key group index < 0 or >= Number of key groups + */ + void registerKvState(int keyGroupIndex, KvStateID kvStateId, KvStateServerAddress kvStateAddress) { + if (keyGroupIndex < 0 || keyGroupIndex >= numKeyGroups) { + throw new IndexOutOfBoundsException("Key group index"); + } + + if (kvStateIds[keyGroupIndex] == null && kvStateAddresses[keyGroupIndex] == null) { + numRegisteredKeyGroups++; + } + + kvStateIds[keyGroupIndex] = kvStateId; + kvStateAddresses[keyGroupIndex] = kvStateAddress; + } + + /** + * Registers a KvState instance for the given key group index. + * + * @param keyGroupIndex Key group index to unregister. + * @throws IndexOutOfBoundsException If key group index < 0 or >= Number of key groups + * @throws IllegalArgumentException If no location information registered for key group index. + */ + void unregisterKvState(int keyGroupIndex) { + if (keyGroupIndex < 0 || keyGroupIndex >= numKeyGroups) { + throw new IndexOutOfBoundsException("Key group index"); + } + + if (kvStateIds[keyGroupIndex] == null || kvStateAddresses[keyGroupIndex] == null) { + throw new IllegalArgumentException("Not registered. 
Probably registration/unregistration race."); + } + + numRegisteredKeyGroups--; + + kvStateIds[keyGroupIndex] = null; + kvStateAddresses[keyGroupIndex] = null; + } + + @Override + public String toString() { + return "KvStateLocation{" + + "jobId=" + jobId + + ", jobVertexId=" + jobVertexId + + ", parallelism=" + numKeyGroups + + ", kvStateIds=" + Arrays.toString(kvStateIds) + + ", kvStateAddresses=" + Arrays.toString(kvStateAddresses) + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { return true; } + if (o == null || getClass() != o.getClass()) { return false; } + + KvStateLocation that = (KvStateLocation) o; + + if (numKeyGroups != that.numKeyGroups) { return false; } + if (!jobId.equals(that.jobId)) { return false; } + if (!jobVertexId.equals(that.jobVertexId)) { return false; } + if (!registrationName.equals(that.registrationName)) { return false; } + if (!Arrays.equals(kvStateIds, that.kvStateIds)) { return false; } + return Arrays.equals(kvStateAddresses, that.kvStateAddresses); + } + + @Override + public int hashCode() { + int result = jobId.hashCode(); + result = 31 * result + jobVertexId.hashCode(); + result = 31 * result + numKeyGroups; + result = 31 * result + registrationName.hashCode(); + result = 31 * result + Arrays.hashCode(kvStateIds); + result = 31 * result + Arrays.hashCode(kvStateAddresses); + return result; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationRegistry.java new file mode 100644 index 0000000000000..5b7659870df26 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationRegistry.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.execution.SuppressRestartsException; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.Map; + +/** + * Simple registry, which maps {@link KvState} registration notifications to + * {@link KvStateLocation} instances. + */ +public class KvStateLocationRegistry { + + /** JobID this coordinator belongs to. */ + private final JobID jobId; + + /** Job vertices for determining parallelism per key. */ + private final Map jobVertices; + + /** + * Location info keyed by registration name. The name needs to be unique + * per JobID, i.e. two operators cannot register KvState with the same + * name. 
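+ * <p>Rough usage sketch on the JobManager side (the name {@code "average"},
+ * key group {@code 0}, and the local variables are placeholders):
+ * <pre>{@code
+ * KvStateLocationRegistry registry = new KvStateLocationRegistry(jobId, graph.getAllVertices());
+ * registry.notifyKvStateRegistered(vertexId, 0, "average", kvStateId, serverAddress);
+ *
+ * KvStateLocation location = registry.getKvStateLocation("average");
+ * KvStateServerAddress address = location.getKvStateServerAddress(0);
+ * }</pre>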
+ */ + private final Map lookupTable = new HashMap<>(); + + /** + * Creates the registry for the job. + * + * @param jobId JobID this coordinator belongs to. + * @param jobVertices Job vertices map of all vertices of this job. + */ + public KvStateLocationRegistry(JobID jobId, Map jobVertices) { + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.jobVertices = Preconditions.checkNotNull(jobVertices, "Job vertices"); + } + + /** + * Returns the {@link KvStateLocation} for the registered KvState instance + * or null if no location information is available. + * + * @param registrationName Name under which the KvState instance is registered. + * @return Location information or null. + */ + public KvStateLocation getKvStateLocation(String registrationName) { + return lookupTable.get(registrationName); + } + + /** + * Notifies the registry about a registered KvState instance. + * + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + * @param kvStateId ID of the registered KvState instance + * @param kvStateServerAddress Server address where to find the KvState instance + * + * @throws IllegalArgumentException If JobVertexID does not belong to job + * @throws IllegalArgumentException If state has been registered with same + * name by another operator. + * @throws IndexOutOfBoundsException If key group index is out of bounds. + */ + public void notifyKvStateRegistered( + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvStateID kvStateId, + KvStateServerAddress kvStateServerAddress) { + + KvStateLocation location = lookupTable.get(registrationName); + + if (location == null) { + // First registration for this operator, create the location info + ExecutionJobVertex vertex = jobVertices.get(jobVertexId); + + if (vertex != null) { + int parallelism = vertex.getParallelism(); + location = new KvStateLocation(jobId, jobVertexId, parallelism, registrationName); + lookupTable.put(registrationName, location); + } else { + throw new IllegalArgumentException("Unknown JobVertexID " + jobVertexId); + } + } + + // Duplicated name if vertex IDs don't match + if (!location.getJobVertexId().equals(jobVertexId)) { + IllegalStateException duplicate = new IllegalStateException( + "Registration name clash. KvState with name '" + registrationName + + "' has already been registered by another operator (" + + location.getJobVertexId() + ")."); + + ExecutionJobVertex vertex = jobVertices.get(jobVertexId); + if (vertex != null) { + vertex.fail(new SuppressRestartsException(duplicate)); + } + + throw duplicate; + } + + location.registerKvState(keyGroupIndex, kvStateId, kvStateServerAddress); + } + + /** + * Notifies the registry about an unregistered KvState instance. 
+ * + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + * @throws IllegalArgumentException If another operator registered the state instance + * @throws IllegalArgumentException If the registration name is not known + */ + public void notifyKvStateUnregistered( + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName) { + + KvStateLocation location = lookupTable.get(registrationName); + + if (location != null) { + // Duplicate name if vertex IDs don't match + if (!location.getJobVertexId().equals(jobVertexId)) { + throw new IllegalArgumentException("Another operator (" + + location.getJobVertexId() + ") registered the KvState " + + "under '" + registrationName + "'."); + } + + location.unregisterKvState(keyGroupIndex); + + if (location.getNumRegisteredKeyGroups() == 0) { + lookupTable.remove(registrationName); + } + } else { + throw new IllegalArgumentException("Unknown registration name '" + + registrationName + "'. " + "Probably registration/unregistration race."); + } + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java new file mode 100644 index 0000000000000..5e3c38e983bdd --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; + +/** + * Actor messages for {@link KvState} lookup and registration. + */ +public interface KvStateMessage extends Serializable { + + // ------------------------------------------------------------------------ + // Lookup + // ------------------------------------------------------------------------ + + class LookupKvStateLocation implements KvStateMessage { + + private static final long serialVersionUID = 1L; + + /** JobID the KvState instance belongs to. */ + private final JobID jobId; + + /** Name under which the KvState has been registered. */ + private final String registrationName; + + /** + * Requests a {@link KvStateLocation} for the specified JobID and + * {@link KvState} registration name. 
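+ * <p>A client resolves a location by sending this message to the JobManager
+ * and waiting on the reply, roughly (gateway, job ID and timeout are assumed
+ * to be at hand; {@code "average"} is an example name):
+ * <pre>{@code
+ * Future<Object> future = jobManagerGateway.ask(
+ *     new LookupKvStateLocation(jobId, "average"), timeout);
+ * // Completes with a KvStateLocation on success, or with a failure such as
+ * // UnknownKvStateLocation if nothing is registered under that name.
+ * }</pre>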
+ * + * @param jobId JobID the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + */ + public LookupKvStateLocation(JobID jobId, String registrationName) { + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.registrationName = Preconditions.checkNotNull(registrationName, "Name"); + } + + /** + * Returns the JobID the KvState instance belongs to. + * + * @return JobID the KvState instance belongs to + */ + public JobID getJobId() { + return jobId; + } + + /** + * Returns the name under which the KvState has been registered. + * + * @return Name under which the KvState has been registered + */ + public String getRegistrationName() { + return registrationName; + } + + @Override + public String toString() { + return "LookupKvStateLocation{" + + "jobId=" + jobId + + ", registrationName='" + registrationName + '\'' + + '}'; + } + } + + // ------------------------------------------------------------------------ + // Registration + // ------------------------------------------------------------------------ + + class NotifyKvStateRegistered implements KvStateMessage { + + private static final long serialVersionUID = 1L; + + /** JobID the KvState instance belongs to. */ + private final JobID jobId; + + /** JobVertexID the KvState instance belongs to. */ + private final JobVertexID jobVertexId; + + /** Key group index the KvState instance belongs to. */ + private final int keyGroupIndex; + + /** Name under which the KvState has been registered. */ + private final String registrationName; + + /** ID of the registered KvState instance. */ + private final KvStateID kvStateId; + + /** Server address where to find the KvState instance. */ + private final KvStateServerAddress kvStateServerAddress; + + /** + * Notifies the JobManager about a registered {@link KvState} instance. + * + * @param jobId JobID the KvState instance belongs to + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + * @param kvStateId ID of the registered KvState instance + * @param kvStateServerAddress Server address where to find the KvState instance + */ + public NotifyKvStateRegistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvStateID kvStateId, + KvStateServerAddress kvStateServerAddress) { + + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); + Preconditions.checkArgument(keyGroupIndex >= 0, "Negative key group index"); + this.keyGroupIndex = keyGroupIndex; + this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); + this.kvStateId = Preconditions.checkNotNull(kvStateId, "KvStateID"); + this.kvStateServerAddress = Preconditions.checkNotNull(kvStateServerAddress, "KvStateServerAddress"); + } + + /** + * Returns the JobID the KvState instance belongs to. + * + * @return JobID the KvState instance belongs to + */ + public JobID getJobId() { + return jobId; + } + + /** + * Returns the JobVertexID the KvState instance belongs to + * + * @return JobVertexID the KvState instance belongs to + */ + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + /** + * Returns the key group index the KvState instance belongs to. 
+ * + * @return Key group index the KvState instance belongs to + */ + public int getKeyGroupIndex() { + return keyGroupIndex; + } + + /** + * Returns the name under which the KvState has been registered. + * + * @return Name under which the KvState has been registered + */ + public String getRegistrationName() { + return registrationName; + } + + /** + * Returns the ID of the registered KvState instance. + * + * @return ID of the registered KvState instance + */ + public KvStateID getKvStateId() { + return kvStateId; + } + + /** + * Returns the server address where to find the KvState instance. + * + * @return Server address where to find the KvState instance + */ + public KvStateServerAddress getKvStateServerAddress() { + return kvStateServerAddress; + } + + @Override + public String toString() { + return "NotifyKvStateRegistered{" + + "jobId=" + jobId + + ", jobVertexId=" + jobVertexId + + ", keyGroupIndex=" + keyGroupIndex + + ", registrationName='" + registrationName + '\'' + + ", kvStateId=" + kvStateId + + ", kvStateServerAddress=" + kvStateServerAddress + + '}'; + } + } + + class NotifyKvStateUnregistered implements KvStateMessage { + + private static final long serialVersionUID = 1L; + + /** JobID the KvState instance belongs to. */ + private final JobID jobId; + + /** JobVertexID the KvState instance belongs to. */ + private final JobVertexID jobVertexId; + + /** Key group index the KvState instance belongs to. */ + private final int keyGroupIndex; + + /** Name under which the KvState has been registered. */ + private final String registrationName; + + /** + * Notifies the JobManager about an unregistered {@link KvState} instance. + * + * @param jobId JobID the KvState instance belongs to + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + */ + public NotifyKvStateUnregistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName) { + + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); + Preconditions.checkArgument(keyGroupIndex >= 0, "Negative key group index"); + this.keyGroupIndex = keyGroupIndex; + this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); + } + + /** + * Returns the JobID the KvState instance belongs to. + * + * @return JobID the KvState instance belongs to + */ + public JobID getJobId() { + return jobId; + } + + /** + * Returns the JobVertexID the KvState instance belongs to + * + * @return JobVertexID the KvState instance belongs to + */ + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + /** + * Returns the key group index the KvState instance belongs to. + * + * @return Key group index the KvState instance belongs to + */ + public int getKeyGroupIndex() { + return keyGroupIndex; + } + + /** + * Returns the name under which the KvState has been registered. 
+ * + * @return Name under which the KvState has been registered + */ + public String getRegistrationName() { + return registrationName; + } + + @Override + public String toString() { + return "NotifyKvStateUnregistered{" + + "jobId=" + jobId + + ", jobVertexId=" + jobVertexId + + ", keyGroupIndex=" + keyGroupIndex + + ", registrationName='" + registrationName + '\'' + + '}'; + } + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistry.java new file mode 100644 index 0000000000000..e09b868b0d344 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistry.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.runtime.taskmanager.Task; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A registry for {@link KvState} instances per task manager. + * + *
<p>
This is currently only used for KvState queries: KvState instances, which + * are marked as queryable in their state descriptor are registered here and + * can be queried by the {@link KvStateServer}. + * + *
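+ * <p>A minimal sketch of the calls involved (the name {@code "average"} and
+ * the surrounding variables are placeholders):
+ * <pre>{@code
+ * KvStateRegistry registry = new KvStateRegistry();
+ * KvStateID id = registry.registerKvState(jobId, jobVertexId, keyGroupIndex, "average", kvState);
+ * // ... the KvStateServer answers queries via registry.getKvState(id) ...
+ * registry.unregisterKvState(jobId, jobVertexId, keyGroupIndex, "average", id);
+ * }</pre>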
<p>
KvState is registered when it is created/restored and unregistered when + * the owning operator stops running. + */ +public class KvStateRegistry { + + /** All registered KvState instances. */ + private final ConcurrentHashMap> registeredKvStates = + new ConcurrentHashMap<>(); + + /** Registry listener to be notified on registration/unregistration. */ + private final AtomicReference listener = new AtomicReference<>(); + + /** + * Registers a listener with the registry. + * + * @param listener The registry listener. + * @throws IllegalStateException If there is a registered listener + */ + public void registerListener(KvStateRegistryListener listener) { + if (!this.listener.compareAndSet(null, listener)) { + throw new IllegalStateException("Listener already registered."); + } + } + + /** + * Registers the KvState instance identified by the given 4-tuple of JobID, + * JobVertexID, key group index, and registration name. + * + * @param kvStateId KvStateID to identify the KvState instance + * @param kvState KvState instance to register + * @throws IllegalStateException If there is a KvState instance registered + * with the same ID. + */ + + /** + * Registers the KvState instance and returns the assigned ID. + * + * @param jobId JobId the KvState instance belongs to + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState is registered + * @param kvState KvState instance to be registered + * @return Assigned KvStateID + */ + public KvStateID registerKvState( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvState kvState) { + + KvStateID kvStateId = new KvStateID(); + + if (registeredKvStates.putIfAbsent(kvStateId, kvState) == null) { + KvStateRegistryListener listener = this.listener.get(); + if (listener != null) { + listener.notifyKvStateRegistered( + jobId, + jobVertexId, + keyGroupIndex, + registrationName, + kvStateId); + } + + return kvStateId; + } else { + throw new IllegalStateException(kvStateId + " is already registered."); + } + } + + /** + * Unregisters the KvState instance identified by the given KvStateID. + * + * @param jobId JobId the KvState instance belongs to + * @param kvStateId KvStateID to identify the KvState instance + */ + public void unregisterKvState( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvStateID kvStateId) { + + if (registeredKvStates.remove(kvStateId) != null) { + KvStateRegistryListener listener = this.listener.get(); + if (listener != null) { + listener.notifyKvStateUnregistered( + jobId, + jobVertexId, + keyGroupIndex, + registrationName); + } + } + } + + /** + * Returns the KvState instance identified by the given KvStateID or + * null if none is registered. + * + * @param kvStateId KvStateID to identify the KvState instance + * @return KvState instance identified by the KvStateID or null + */ + public KvState getKvState(KvStateID kvStateId) { + return registeredKvStates.get(kvStateId); + } + + // ------------------------------------------------------------------------ + + /** + * Creates a {@link TaskKvStateRegistry} facade for the {@link Task} + * identified by the given JobID and JobVertexID instance. 
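+ * <p>The facade is created by the {@code NetworkEnvironment} and handed to
+ * the task's environment, roughly:
+ * <pre>{@code
+ * TaskKvStateRegistry taskRegistry =
+ *     networkEnvironment.createKvStateTaskRegistry(jobId, jobVertexId);
+ * // exposed to operators and state backends via Environment#getTaskKvStateRegistry()
+ * }</pre>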
+ * + * @param jobId JobID of the task + * @param jobVertexId JobVertexID of the task + * @return A {@link TaskKvStateRegistry} facade for the task + */ + public TaskKvStateRegistry createTaskRegistry(JobID jobId, JobVertexID jobVertexId) { + return new TaskKvStateRegistry(this, jobId, jobVertexId); + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistryListener.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistryListener.java new file mode 100644 index 0000000000000..760adf16d7811 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateRegistryListener.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; + +/** + * A listener for a {@link KvStateRegistry}. + * + *
<p>
The registry calls these methods when KvState instances are registered + * and unregistered. + */ +public interface KvStateRegistryListener { + + /** + * Notifies the listener about a registered KvState instance. + * + * @param jobId Job ID the KvState instance belongs to + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState is registered + * @param kvStateId ID of the KvState instance + */ + void notifyKvStateRegistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName, + KvStateID kvStateId); + + /** + * Notifies the listener about an unregistered KvState instance. + * + * @param jobId Job ID the KvState instance belongs to + * @param jobVertexId JobVertexID the KvState instance belongs to + * @param keyGroupIndex Key group index the KvState instance belongs to + * @param registrationName Name under which the KvState is registered + */ + void notifyKvStateUnregistered( + JobID jobId, + JobVertexID jobVertexId, + int keyGroupIndex, + String registrationName); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java new file mode 100644 index 0000000000000..15f0160d76097 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/TaskKvStateRegistry.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +/** + * A helper for KvState registrations of a single task. + */ +public class TaskKvStateRegistry { + + /** KvStateRegistry for KvState instance registrations. */ + private final KvStateRegistry registry; + + /** JobID of the task. */ + private final JobID jobId; + + /** JobVertexID of the task. */ + private final JobVertexID jobVertexId; + + /** List of all registered KvState instances of this task. */ + private final List registeredKvStates = new ArrayList<>(); + + TaskKvStateRegistry(KvStateRegistry registry, JobID jobId, JobVertexID jobVertexId) { + this.registry = Preconditions.checkNotNull(registry, "KvStateRegistry"); + this.jobId = Preconditions.checkNotNull(jobId, "JobID"); + this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID"); + } + + /** + * Registers the KvState instance at the KvStateRegistry. 
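+ * <p>A state backend would typically call this when it creates a KvState
+ * whose descriptor was flagged as queryable; a sketch (the descriptor
+ * accessors are assumptions based on the queryable flag added to
+ * {@code StateDescriptor} in this series, not methods of this class):
+ * <pre>{@code
+ * if (stateDesc.isQueryable()) {
+ *     env.getTaskKvStateRegistry().registerKvState(
+ *         keyGroupIndex, stateDesc.getQueryableStateName(), kvState);
+ * }
+ * }</pre>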
+ * + * @param keyGroupIndex KeyGroupIndex the KvState instance belongs to + * @param registrationName The registration name (not necessarily the same + * as the KvState name defined in the state + * descriptor used to create the KvState instance) + * @param kvState The + */ + public void registerKvState(int keyGroupIndex, String registrationName, KvState kvState) { + KvStateID kvStateId = registry.registerKvState(jobId, jobVertexId, keyGroupIndex, registrationName, kvState); + registeredKvStates.add(new KvStateInfo(keyGroupIndex, registrationName, kvStateId)); + } + + /** + * Unregisters all registered KvState instances from the KvStateRegistry. + */ + public void unregisterAll() { + for (KvStateInfo kvState : registeredKvStates) { + registry.unregisterKvState(jobId, jobVertexId, kvState.keyGroupIndex, kvState.registrationName, kvState.kvStateId); + } + } + + /** + * 3-tuple holding registered KvState meta data. + */ + private static class KvStateInfo { + + private final int keyGroupIndex; + + private final String registrationName; + + private final KvStateID kvStateId; + + public KvStateInfo(int keyGroupIndex, String registrationName, KvStateID kvStateId) { + this.keyGroupIndex = keyGroupIndex; + this.registrationName = registrationName; + this.kvStateId = kvStateId; + } + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 6fdf6f9492692..69587846090fc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -36,6 +36,7 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.util.SerializedValue; @@ -75,6 +76,8 @@ public class RuntimeEnvironment implements Environment { private final AccumulatorRegistry accumulatorRegistry; + private final TaskKvStateRegistry kvStateRegistry; + private final TaskManagerRuntimeInfo taskManagerInfo; private final TaskMetricGroup metrics; @@ -95,6 +98,7 @@ public RuntimeEnvironment( IOManager ioManager, BroadcastVariableManager bcVarManager, AccumulatorRegistry accumulatorRegistry, + TaskKvStateRegistry kvStateRegistry, InputSplitProvider splitProvider, Map> distCacheEntries, ResultPartitionWriter[] writers, @@ -116,6 +120,7 @@ public RuntimeEnvironment( this.ioManager = checkNotNull(ioManager); this.bcVarManager = checkNotNull(bcVarManager); this.accumulatorRegistry = checkNotNull(accumulatorRegistry); + this.kvStateRegistry = checkNotNull(kvStateRegistry); this.splitProvider = checkNotNull(splitProvider); this.distCacheEntries = checkNotNull(distCacheEntries); this.writers = checkNotNull(writers); @@ -198,6 +203,11 @@ public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; } + @Override + public TaskKvStateRegistry getTaskKvStateRegistry() { + return kvStateRegistry; + } + @Override public InputSplitProvider getInputSplitProvider() { return splitProvider; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index dbc0b6235ee6d..c98d51204b63d 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -57,6 +57,7 @@ import org.apache.flink.runtime.messages.TaskMessages.TaskInFinalState; import org.apache.flink.runtime.messages.TaskMessages.UpdateTaskExecutionState; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.state.StateUtils; import org.apache.flink.util.SerializedValue; @@ -519,10 +520,14 @@ else if (current == ExecutionState.CANCELING) { TaskInputSplitProvider splitProvider = new TaskInputSplitProvider(jobManager, jobId, vertexId, executionId, userCodeClassLoader, actorAskTimeout); + TaskKvStateRegistry kvStateRegistry = network + .createKvStateTaskRegistry(jobId, getJobVertexId()); + Environment env = new RuntimeEnvironment(jobId, vertexId, executionId, executionConfig, taskInfo, jobConfiguration, taskConfiguration, userCodeClassLoader, memoryManager, ioManager, broadcastVariableManager, accumulatorRegistry, + kvStateRegistry, splitProvider, distributedCacheEntries, writers, inputGates, jobManager, taskManagerConfig, metrics, this); diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala index 84d38c1c97fba..01af3c1404e06 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala @@ -25,7 +25,7 @@ import java.util.UUID import java.util.concurrent.{ExecutorService, TimeUnit, TimeoutException} import javax.management.ObjectName -import akka.actor.Status.Failure +import akka.actor.Status.{Success, Failure} import akka.actor._ import akka.pattern.ask @@ -74,6 +74,8 @@ import org.apache.flink.runtime.messages.webmonitor._ import org.apache.flink.runtime.metrics.{MetricRegistry => FlinkMetricRegistry} import org.apache.flink.runtime.metrics.groups.JobManagerMetricGroup import org.apache.flink.runtime.process.ProcessReaper +import org.apache.flink.runtime.query.{UnknownKvStateLocation, KvStateMessage} +import org.apache.flink.runtime.query.KvStateMessage.{NotifyKvStateUnregistered, LookupKvStateLocation, NotifyKvStateRegistered} import org.apache.flink.runtime.security.SecurityUtils import org.apache.flink.runtime.security.SecurityUtils.FlinkSecuredRunner import org.apache.flink.runtime.taskmanager.TaskManager @@ -678,6 +680,9 @@ class JobManager( case checkpointMessage : AbstractCheckpointMessage => handleCheckpointMessage(checkpointMessage) + case kvStateMsg : KvStateMessage => + handleKvStateMessage(kvStateMsg) + case TriggerSavepoint(jobId) => currentJobs.get(jobId) match { case Some((graph, _)) => @@ -1435,6 +1440,75 @@ class JobManager( case _ => unhandled(actorMessage) } } + + /** + * Handle all [KvStateMessage] instances for KvState location lookups and + * registration. + * + * @param actorMsg The KvState actor message. + */ + private def handleKvStateMessage(actorMsg: KvStateMessage): Unit = { + actorMsg match { + // Client KvStateLocation lookup + case msg: LookupKvStateLocation => + currentJobs.get(msg.getJobId) match { + case Some((graph, _)) => + try { + val registry = graph.getKvStateLocationRegistry + val location = registry.getKvStateLocation(msg.getRegistrationName) + if (location == null) { + sender() ! 
Failure(new UnknownKvStateLocation(msg.getRegistrationName)) + } else { + sender() ! Success(location) + } + } catch { + case t: Throwable => + sender() ! Failure(t) + } + + case None => + sender() ! Status.Failure(new IllegalStateException(s"Job ${msg.getJobId} not found")) + } + + // TaskManager KvState registration + case msg: NotifyKvStateRegistered => + currentJobs.get(msg.getJobId) match { + case Some((graph, _)) => + try { + graph.getKvStateLocationRegistry.notifyKvStateRegistered( + msg.getJobVertexId, + msg.getKeyGroupIndex, + msg.getRegistrationName, + msg.getKvStateId, + msg.getKvStateServerAddress) + } catch { + case t: Throwable => + log.error(s"Failed to notify KvStateRegistry about registration $msg.") + } + + case None => log.error(s"Received $msg for unavailable job.") + } + + // TaskManager KvState unregistration + case msg: NotifyKvStateUnregistered => + currentJobs.get(msg.getJobId) match { + case Some((graph, _)) => + try { + graph.getKvStateLocationRegistry.notifyKvStateUnregistered( + msg.getJobVertexId, + msg.getKeyGroupIndex, + msg.getRegistrationName) + } catch { + case t: Throwable => + log.error(s"Failed to notify KvStateRegistry about registration $msg.") + } + + case None => log.error(s"Received $msg for unavailable job.") + } + + case _ => unhandled(actorMsg) + } + } /** * Handle unmatched messages with an exception. diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index 226fa75b19a4f..7c4b867a114ec 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -1809,7 +1809,11 @@ object TaskManager { val executionContext = ExecutionContext.fromExecutor(new ForkJoinPool()) // we start the network first, to make sure it can allocate its buffers first - val network = new NetworkEnvironment(executionContext, taskManagerConfig.timeout, netConfig) + val network = new NetworkEnvironment( + executionContext, + taskManagerConfig.timeout, + netConfig, + connectionInfo) // computing the amount of memory to use depends on how much memory is available // it strictly needs to happen AFTER the network stack has been initialized diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index fca3cebdb0885..938e66190e80e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -23,6 +23,7 @@ import org.apache.flink.core.memory.MemoryType; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.instance.DummyActorGateway; +import org.apache.flink.runtime.instance.InstanceConnectionInfo; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.BufferPool; @@ -85,7 +86,8 @@ public void testAssociateDisassociate() { NetworkEnvironment env = new NetworkEnvironment( TestingUtils.defaultExecutionContext(), new FiniteDuration(30, TimeUnit.SECONDS), - config); + config, + new InstanceConnectionInfo(InetAddress.getLocalHost(), port)); assertFalse(env.isShutdown()); 
assertFalse(env.isAssociated()); @@ -178,7 +180,8 @@ public void testEagerlyDeployConsumers() throws Exception { NetworkEnvironment env = new NetworkEnvironment( TestingUtils.defaultExecutionContext(), new FiniteDuration(30, TimeUnit.SECONDS), - config); + config, + new InstanceConnectionInfo(InetAddress.getLocalHost(), 12232)); // Associate the environment with the mock actors env.associateWithTaskManagerAndJobManager( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java index 5c25003bb7e78..f925d628ca338 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerTest.java @@ -18,57 +18,83 @@ package org.apache.flink.runtime.jobmanager; +import akka.actor.ActorRef; import akka.actor.ActorSystem; import akka.testkit.JavaTestKit; - import com.typesafe.config.Config; - -import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.akka.ListeningBehaviour; +import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; -import org.apache.flink.runtime.instance.AkkaActorGateway; import org.apache.flink.runtime.instance.ActorGateway; +import org.apache.flink.runtime.instance.AkkaActorGateway; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; -import org.apache.flink.runtime.messages.JobManagerMessages; +import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; +import org.apache.flink.runtime.leaderretrieval.StandaloneLeaderRetrievalService; import org.apache.flink.runtime.messages.JobManagerMessages.LeaderSessionMessage; +import org.apache.flink.runtime.messages.JobManagerMessages.RequestPartitionState; import org.apache.flink.runtime.messages.JobManagerMessages.StopJob; import org.apache.flink.runtime.messages.JobManagerMessages.StoppingFailure; import org.apache.flink.runtime.messages.JobManagerMessages.StoppingSuccess; import org.apache.flink.runtime.messages.JobManagerMessages.SubmitJob; -import org.apache.flink.runtime.messages.JobManagerMessages.RequestPartitionState; import org.apache.flink.runtime.messages.TaskMessages.PartitionState; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateLocation; +import org.apache.flink.runtime.query.KvStateMessage.LookupKvStateLocation; +import org.apache.flink.runtime.query.KvStateMessage.NotifyKvStateRegistered; +import org.apache.flink.runtime.query.KvStateMessage.NotifyKvStateUnregistered; 
+import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.UnknownKvStateLocation; +import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingCluster; +import org.apache.flink.runtime.testingUtils.TestingJobManager; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.ExecutionGraphFound; +import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobStatus; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.RequestExecutionGraph; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning; import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunningOrFinished; +import org.apache.flink.runtime.testingUtils.TestingTaskManager; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testutils.StoppableInvokable; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; - import scala.Some; import scala.Tuple2; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.reflect.ClassTag$; import java.net.InetAddress; +import java.util.UUID; +import java.util.concurrent.TimeUnit; import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.PIPELINED; +import static org.apache.flink.runtime.messages.JobManagerMessages.JobResultSuccess; +import static org.apache.flink.runtime.messages.JobManagerMessages.JobSubmitSuccess; +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.AllVerticesRunning; +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.JobStatusIs; +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenAtLeastNumTaskManagerAreRegistered; import static org.apache.flink.runtime.testingUtils.TestingUtils.DEFAULT_AKKA_ASK_TIMEOUT; import static org.apache.flink.runtime.testingUtils.TestingUtils.startTestingCluster; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -141,13 +167,13 @@ protected void run() { jobGraph, ListeningBehaviour.EXECUTION_RESULT), testActorGateway); - expectMsgClass(JobManagerMessages.JobSubmitSuccess.class); + expectMsgClass(JobSubmitSuccess.class); jobManagerGateway.tell( new WaitForAllVerticesToBeRunningOrFinished(jid), testActorGateway); - expectMsgClass(TestingJobManagerMessages.AllVerticesRunning.class); + expectMsgClass(AllVerticesRunning.class); // This is the mock execution ID of the task requesting the state of the partition final ExecutionAttemptID receiver = new ExecutionAttemptID(); @@ -267,17 +293,17 @@ protected void run() { jobGraph, ListeningBehaviour.EXECUTION_RESULT), testActorGateway); - expectMsgClass(JobManagerMessages.JobSubmitSuccess.class); + expectMsgClass(JobSubmitSuccess.class); jobManagerGateway.tell(new WaitForAllVerticesToBeRunning(jid), testActorGateway); - expectMsgClass(TestingJobManagerMessages.AllVerticesRunning.class); + expectMsgClass(AllVerticesRunning.class); jobManagerGateway.tell(new StopJob(jid), testActorGateway); // - The test 
---------------------------------------------------------------------- expectMsgClass(StoppingSuccess.class); - expectMsgClass(JobManagerMessages.JobResultSuccess.class); + expectMsgClass(JobResultSuccess.class); } finally { if (cluster != null) { cluster.shutdown(); @@ -319,10 +345,10 @@ protected void run() { jobGraph, ListeningBehaviour.EXECUTION_RESULT), testActorGateway); - expectMsgClass(JobManagerMessages.JobSubmitSuccess.class); + expectMsgClass(JobSubmitSuccess.class); jobManagerGateway.tell(new WaitForAllVerticesToBeRunning(jid), testActorGateway); - expectMsgClass(TestingJobManagerMessages.AllVerticesRunning.class); + expectMsgClass(AllVerticesRunning.class); jobManagerGateway.tell(new StopJob(jid), testActorGateway); @@ -342,4 +368,206 @@ protected void run() { }}; } + /** + * Tests that the JobManager handles {@link org.apache.flink.runtime.query.KvStateMessage} + * instances as expected. + */ + @Test + public void testKvStateMessages() throws Exception { + Deadline deadline = new FiniteDuration(100, TimeUnit.SECONDS).fromNow(); + + Configuration config = new Configuration(); + config.setString(ConfigConstants.AKKA_ASK_TIMEOUT, "100ms"); + + UUID leaderSessionId = null; + ActorGateway jobManager = new AkkaActorGateway( + JobManager.startJobManagerActors( + config, + system, + TestingJobManager.class, + MemoryArchivist.class)._1(), + leaderSessionId); + + LeaderRetrievalService leaderRetrievalService = new StandaloneLeaderRetrievalService( + AkkaUtils.getAkkaURL(system, jobManager.actor())); + + Configuration tmConfig = new Configuration(); + tmConfig.setInteger(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, 4); + tmConfig.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 8); + + ActorRef taskManager = TaskManager.startTaskManagerComponentsAndActor( + tmConfig, + ResourceID.generate(), + system, + "localhost", + scala.Option.empty(), + scala.Option.apply(leaderRetrievalService), + true, + TestingTaskManager.class); + + Future registrationFuture = jobManager + .ask(new NotifyWhenAtLeastNumTaskManagerAreRegistered(1), deadline.timeLeft()); + + Await.ready(registrationFuture, deadline.timeLeft()); + + // + // Location lookup + // + LookupKvStateLocation lookupNonExistingJob = new LookupKvStateLocation( + new JobID(), + "any-name"); + + Future lookupFuture = jobManager + .ask(lookupNonExistingJob, deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + try { + Await.result(lookupFuture, deadline.timeLeft()); + fail("Did not throw expected Exception"); + } catch (IllegalStateException ignored) { + // Expected + } + + JobGraph jobGraph = new JobGraph("croissant"); + JobVertex jobVertex1 = new JobVertex("cappuccino"); + jobVertex1.setParallelism(4); + jobVertex1.setInvokableClass(Tasks.BlockingNoOpInvokable.class); + + JobVertex jobVertex2 = new JobVertex("americano"); + jobVertex2.setParallelism(4); + jobVertex2.setInvokableClass(Tasks.BlockingNoOpInvokable.class); + + jobGraph.addVertex(jobVertex1); + jobGraph.addVertex(jobVertex2); + + Future submitFuture = jobManager + .ask(new SubmitJob(jobGraph, ListeningBehaviour.DETACHED), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(JobSubmitSuccess.class)); + + Await.result(submitFuture, deadline.timeLeft()); + + Object lookupUnknownRegistrationName = new LookupKvStateLocation( + jobGraph.getJobID(), + "unknown"); + + lookupFuture = jobManager + .ask(lookupUnknownRegistrationName, deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + try { + Await.result(lookupFuture, 
deadline.timeLeft()); + fail("Did not throw expected Exception"); + } catch (UnknownKvStateLocation ignored) { + // Expected + } + + // + // Registration + // + NotifyKvStateRegistered registerNonExistingJob = new NotifyKvStateRegistered( + new JobID(), + new JobVertexID(), + 0, + "any-name", + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 1233)); + + jobManager.tell(registerNonExistingJob); + + LookupKvStateLocation lookupAfterRegistration = new LookupKvStateLocation( + registerNonExistingJob.getJobId(), + registerNonExistingJob.getRegistrationName()); + + lookupFuture = jobManager + .ask(lookupAfterRegistration, deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + try { + Await.result(lookupFuture, deadline.timeLeft()); + fail("Did not throw expected Exception"); + } catch (IllegalStateException ignored) { + // Expected + } + + NotifyKvStateRegistered registerForExistingJob = new NotifyKvStateRegistered( + jobGraph.getJobID(), + jobVertex1.getID(), + 0, + "register-me", + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 1293)); + + jobManager.tell(registerForExistingJob); + + lookupAfterRegistration = new LookupKvStateLocation( + registerForExistingJob.getJobId(), + registerForExistingJob.getRegistrationName()); + + lookupFuture = jobManager + .ask(lookupAfterRegistration, deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + KvStateLocation location = Await.result(lookupFuture, deadline.timeLeft()); + assertNotNull(location); + + assertEquals(jobGraph.getJobID(), location.getJobId()); + assertEquals(jobVertex1.getID(), location.getJobVertexId()); + assertEquals(jobVertex1.getParallelism(), location.getNumKeyGroups()); + assertEquals(1, location.getNumRegisteredKeyGroups()); + int keyGroupIndex = registerForExistingJob.getKeyGroupIndex(); + assertEquals(registerForExistingJob.getKvStateId(), location.getKvStateID(keyGroupIndex)); + assertEquals(registerForExistingJob.getKvStateServerAddress(), location.getKvStateServerAddress(keyGroupIndex)); + + // + // Unregistration + // + NotifyKvStateUnregistered unregister = new NotifyKvStateUnregistered( + registerForExistingJob.getJobId(), + registerForExistingJob.getJobVertexId(), + registerForExistingJob.getKeyGroupIndex(), + registerForExistingJob.getRegistrationName()); + + jobManager.tell(unregister); + + lookupFuture = jobManager + .ask(lookupAfterRegistration, deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + try { + Await.result(lookupFuture, deadline.timeLeft()); + fail("Did not throw expected Exception"); + } catch (UnknownKvStateLocation ignored) { + // Expected + } + + // + // Duplicate registration fails task + // + NotifyKvStateRegistered register = new NotifyKvStateRegistered( + jobGraph.getJobID(), + jobVertex1.getID(), + 0, + "duplicate-me", + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 1293)); + + NotifyKvStateRegistered duplicate = new NotifyKvStateRegistered( + jobGraph.getJobID(), + jobVertex2.getID(), // <--- different operator, but... 
+ 0, + "duplicate-me", // ...same name + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 1293)); + + Future failedFuture = jobManager + .ask(new NotifyWhenJobStatus(jobGraph.getJobID(), JobStatus.FAILED), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(JobStatusIs.class)); + + jobManager.tell(register); + jobManager.tell(duplicate); + + // Wait for failure + JobStatusIs jobStatus = Await.result(failedFuture, deadline.timeLeft()); + assertEquals(JobStatus.FAILED, jobStatus.state()); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 5af34fb170cc5..87540bc40938d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -34,6 +34,8 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -48,9 +50,17 @@ public class DummyEnvironment implements Environment { private final ExecutionAttemptID executionId = new ExecutionAttemptID(); private final ExecutionConfig executionConfig = new ExecutionConfig(); private final TaskInfo taskInfo; + private final KvStateRegistry kvStateRegistry = new KvStateRegistry(); + private final TaskKvStateRegistry taskKvStateRegistry; public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) { this.taskInfo = new TaskInfo(taskName, subTaskIndex, numSubTasks, 0); + + this.taskKvStateRegistry = kvStateRegistry.createTaskRegistry(jobId, jobVertexId); + } + + public KvStateRegistry getKvStateRegistry() { + return kvStateRegistry; } @Override @@ -133,6 +143,11 @@ public AccumulatorRegistry getAccumulatorRegistry() { return null; } + @Override + public TaskKvStateRegistry getTaskKvStateRegistry() { + return taskKvStateRegistry; + } + @Override public void acknowledgeCheckpoint(long checkpointId) {} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java.orig b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java.orig new file mode 100644 index 0000000000000..393ee4c920223 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java.orig @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
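For orientation, the location-lookup half of the protocol exercised by testKvStateMessages() reduces to the following sketch; jobManager (an ActorGateway), jobId and timeout are assumed to be in scope as in the test, and "query-name" stands for the name under which the state was registered:

    Future<KvStateLocation> locationFuture = jobManager
        .ask(new LookupKvStateLocation(jobId, "query-name"), timeout)
        .mapTo(ClassTag$.MODULE$.<KvStateLocation>apply(KvStateLocation.class));

    KvStateLocation location = Await.result(locationFuture, timeout);

    // Address and ID of the KvState instance serving one key group.
    int keyGroupIndex = 0;
    KvStateID kvStateId = location.getKvStateID(keyGroupIndex);
    KvStateServerAddress address = location.getKvStateServerAddress(keyGroupIndex);

An UnknownKvStateLocation failure simply means that nothing has been registered under that name yet.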
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators.testutils; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; +import org.apache.flink.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.accumulators.AccumulatorRegistry; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.Future; + +public class DummyEnvironment implements Environment { + + private final JobID jobId = new JobID(); + private final JobVertexID jobVertexId = new JobVertexID(); + private final ExecutionAttemptID executionId = new ExecutionAttemptID(); + private final ExecutionConfig executionConfig = new ExecutionConfig(); +<<<<<<< 9a73dbc71b83080b7deccc62b8b6ffa9f102e847 + private final TaskInfo taskInfo; +======= + private final KvStateRegistry kvStateRegistry = new KvStateRegistry(); + private final TaskKvStateRegistry taskKvStateRegistry; +>>>>>>> [FLINK-3779] [runtime] Add KvStateRegistry for queryable KvState + + public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) { + this.taskInfo = new TaskInfo(taskName, subTaskIndex, numSubTasks, 0); + + this.taskKvStateRegistry = kvStateRegistry.createTaskRegistry(jobId, jobVertexId); + } + + public KvStateRegistry getKvStateRegistry() { + return kvStateRegistry; + } + + @Override + public ExecutionConfig getExecutionConfig() { + return executionConfig; + } + + @Override + public JobID getJobID() { + return jobId; + } + + @Override + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + @Override + public ExecutionAttemptID getExecutionId() { + return executionId; + } + + @Override + public Configuration getTaskConfiguration() { + return new Configuration(); + } + + @Override + public TaskManagerRuntimeInfo getTaskManagerInfo() { + return null; + } + + @Override + public TaskMetricGroup getMetricGroup() { + return new UnregisteredTaskMetricsGroup(); + } + + @Override + public Configuration getJobConfiguration() { + return new Configuration(); + } + + @Override + public TaskInfo getTaskInfo() { + return taskInfo; + } + + @Override + public InputSplitProvider getInputSplitProvider() { + return null; + } + + @Override + public IOManager getIOManager() { + return null; + } + + @Override + public MemoryManager getMemoryManager() { + return null; + } + + @Override + public ClassLoader getUserClassLoader() { + return getClass().getClassLoader(); + } + + @Override + public Map> getDistributedCacheEntries() { + 
return Collections.emptyMap(); + } + + @Override + public BroadcastVariableManager getBroadcastVariableManager() { + return null; + } + + @Override + public AccumulatorRegistry getAccumulatorRegistry() { + return null; + } + + @Override +<<<<<<< 9a73dbc71b83080b7deccc62b8b6ffa9f102e847 + public void acknowledgeCheckpoint(long checkpointId) {} +======= + public TaskKvStateRegistry getTaskKvStateRegistry() { + return taskKvStateRegistry; + } + + @Override + public void acknowledgeCheckpoint(long checkpointId) { + } +>>>>>>> [FLINK-3779] [runtime] Add KvStateRegistry for queryable KvState + + @Override + public void acknowledgeCheckpoint(long checkpointId, StateHandle state) {} + + @Override + public ResultPartitionWriter getWriter(int index) { + return null; + } + + @Override + public ResultPartitionWriter[] getAllWriters() { + return null; + } + + @Override + public InputGate getInputGate(int index) { + return null; + } + + @Override + public InputGate[] getAllInputGates() { + return null; + } + +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 9dea3242e09a8..7b966c348de1c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -31,7 +31,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; -import org.apache.flink.runtime.io.network.partition.consumer.IteratorWrappingTestSingleInputGate; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; @@ -39,10 +38,13 @@ import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.IteratorWrappingTestSingleInputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.types.Record; @@ -89,6 +91,8 @@ public class MockEnvironment implements Environment { private final AccumulatorRegistry accumulatorRegistry; + private final TaskKvStateRegistry kvStateRegistry; + private final int bufferSize; public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { @@ -105,6 +109,9 @@ public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider this.bufferSize = bufferSize; this.accumulatorRegistry = new AccumulatorRegistry(jobID, getExecutionId()); + + KvStateRegistry registry = new KvStateRegistry(); + this.kvStateRegistry = registry.createTaskRegistry(jobID, getJobVertexId()); } public 
IteratorWrappingTestSingleInputGate addInput(MutableObjectIterator inputIterator) { @@ -280,6 +287,11 @@ public AccumulatorRegistry getAccumulatorRegistry() { return this.accumulatorRegistry; } + @Override + public TaskKvStateRegistry getTaskKvStateRegistry() { + return kvStateRegistry; + } + @Override public void acknowledgeCheckpoint(long checkpointId) { throw new UnsupportedOperationException(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationRegistryTest.java new file mode 100644 index 0000000000000..70f0ba29fd1ba --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationRegistryTest.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.junit.Test; + +import java.net.InetAddress; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KvStateLocationRegistryTest { + + /** + * Simple test registering/unregistereing state and looking it up again. 
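+ *
+ * <p>For reference, the registry calls under test reduce to this sketch, where
+ * jobId, vertexMap, vertexId, keyGroupIndex, kvStateId and serverAddress stand
+ * in for the values built by the individual tests:
+ * <pre>{@code
+ * KvStateLocationRegistry registry = new KvStateLocationRegistry(jobId, vertexMap);
+ * registry.notifyKvStateRegistered(vertexId, keyGroupIndex, "my-state", kvStateId, serverAddress);
+ *
+ * KvStateLocation location = registry.getKvStateLocation("my-state");
+ * KvStateServerAddress address = location.getKvStateServerAddress(keyGroupIndex);
+ *
+ * registry.notifyKvStateUnregistered(vertexId, keyGroupIndex, "my-state");
+ * }</pre>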
+ */ + @Test + public void testRegisterAndLookup() throws Exception { + String[] registrationNames = new String[] { + "TAsIrGnc7MULwVupNKZ0", + "086133IrGn0Ii2853237" }; + + ExecutionJobVertex[] vertices = new ExecutionJobVertex[] { + createJobVertex(32), + createJobVertex(13) }; + + // IDs for each key group of each vertex + KvStateID[][] ids = new KvStateID[vertices.length][]; + for (int i = 0; i < ids.length; i++) { + ids[i] = new KvStateID[vertices[i].getParallelism()]; + for (int j = 0; j < vertices[i].getParallelism(); j++) { + ids[i][j] = new KvStateID(); + } + } + + KvStateServerAddress server = new KvStateServerAddress(InetAddress.getLocalHost(), 12032); + + // Create registry + Map vertexMap = createVertexMap(vertices); + KvStateLocationRegistry registry = new KvStateLocationRegistry(new JobID(), vertexMap); + + // Register + for (int i = 0; i < vertices.length; i++) { + int numKeyGroups = vertices[i].getParallelism(); + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + // Register + registry.notifyKvStateRegistered( + vertices[i].getJobVertexId(), + keyGroupIndex, + registrationNames[i], + ids[i][keyGroupIndex], + server); + } + } + + // Verify all registrations + for (int i = 0; i < vertices.length; i++) { + KvStateLocation location = registry.getKvStateLocation(registrationNames[i]); + assertNotNull(location); + + int parallelism = vertices[i].getParallelism(); + for (int keyGroupIndex = 0; keyGroupIndex < parallelism; keyGroupIndex++) { + assertEquals(ids[i][keyGroupIndex], location.getKvStateID(keyGroupIndex)); + assertEquals(server, location.getKvStateServerAddress(keyGroupIndex)); + } + } + + // Unregister + for (int i = 0; i < vertices.length; i++) { + int numKeyGroups = vertices[i].getParallelism(); + JobVertexID jobVertexId = vertices[i].getJobVertexId(); + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + registry.notifyKvStateUnregistered(jobVertexId, keyGroupIndex, registrationNames[i]); + } + } + + for (int i = 0; i < registrationNames.length; i++) { + assertNull(registry.getKvStateLocation(registrationNames[i])); + } + } + + /** + * Tests that registrations with duplicate names throw an Exception. + */ + @Test + public void testRegisterDuplicateName() throws Exception { + ExecutionJobVertex[] vertices = new ExecutionJobVertex[] { + createJobVertex(32), + createJobVertex(13) }; + + Map vertexMap = createVertexMap(vertices); + + String registrationName = "duplicated-name"; + KvStateLocationRegistry registry = new KvStateLocationRegistry(new JobID(), vertexMap); + + // First operator registers + registry.notifyKvStateRegistered( + vertices[0].getJobVertexId(), + 0, + registrationName, + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 12328)); + + try { + // Second operator registers same name + registry.notifyKvStateRegistered( + vertices[1].getJobVertexId(), + 0, + registrationName, + new KvStateID(), + new KvStateServerAddress(InetAddress.getLocalHost(), 12032)); + + fail("Did not throw expected Exception after duplicated name"); + } catch (IllegalStateException ignored) { + // Expected + } + } + + /** + * Tests exception on unregistration before registration. 
+ */ + @Test + public void testUnregisterBeforeRegister() throws Exception { + ExecutionJobVertex vertex = createJobVertex(4); + Map vertexMap = createVertexMap(vertex); + + KvStateLocationRegistry registry = new KvStateLocationRegistry(new JobID(), vertexMap); + try { + registry.notifyKvStateUnregistered(vertex.getJobVertexId(), 0, "any-name"); + fail("Did not throw expected Exception, because of missing registration"); + } catch (IllegalArgumentException ignored) { + // Expected + } + } + + /** + * Tests failures during unregistration. + */ + @Test + public void testUnregisterFailures() throws Exception { + String name = "IrGnc73237TAs"; + + ExecutionJobVertex[] vertices = new ExecutionJobVertex[] { + createJobVertex(32), + createJobVertex(13) }; + + Map vertexMap = new HashMap<>(); + for (ExecutionJobVertex vertex : vertices) { + vertexMap.put(vertex.getJobVertexId(), vertex); + } + + KvStateLocationRegistry registry = new KvStateLocationRegistry(new JobID(), vertexMap); + + // First operator registers name + registry.notifyKvStateRegistered( + vertices[0].getJobVertexId(), + 0, + name, + new KvStateID(), + mock(KvStateServerAddress.class)); + + try { + // Unregister not registered keyGroupIndex + int notRegisteredKeyGroupIndex = 2; + + registry.notifyKvStateUnregistered( + vertices[0].getJobVertexId(), + notRegisteredKeyGroupIndex, + name); + + fail("Did not throw expected Exception"); + } catch (IllegalArgumentException expected) { + } + + try { + // Wrong operator tries to unregister + registry.notifyKvStateUnregistered( + vertices[1].getJobVertexId(), + 0, + name); + + fail("Did not throw expected Exception"); + } catch (IllegalArgumentException expected) { + } + } + + // ------------------------------------------------------------------------ + + private ExecutionJobVertex createJobVertex(int parallelism) { + JobVertexID id = new JobVertexID(); + ExecutionJobVertex vertex = mock(ExecutionJobVertex.class); + + when(vertex.getJobVertexId()).thenReturn(id); + when(vertex.getParallelism()).thenReturn(parallelism); + + return vertex; + } + + private Map createVertexMap(ExecutionJobVertex... vertices) { + Map vertexMap = new HashMap<>(); + for (ExecutionJobVertex vertex : vertices) { + vertexMap.put(vertex.getJobVertexId(), vertex); + } + return vertexMap; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationTest.java new file mode 100644 index 0000000000000..59ac575002344 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/KvStateLocationTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.junit.Test; + +import java.net.InetAddress; + +import static org.junit.Assert.assertEquals; + +public class KvStateLocationTest { + + /** + * Simple test registering/unregistereing state and looking it up again. + */ + @Test + public void testRegisterAndLookup() throws Exception { + JobID jobId = new JobID(); + JobVertexID jobVertexId = new JobVertexID(); + int numKeyGroups = 123; + String registrationName = "asdasdasdasd"; + + KvStateLocation location = new KvStateLocation(jobId, jobVertexId, numKeyGroups, registrationName); + + KvStateID[] kvStateIds = new KvStateID[numKeyGroups]; + KvStateServerAddress[] serverAddresses = new KvStateServerAddress[numKeyGroups]; + + InetAddress host = InetAddress.getLocalHost(); + + // Register + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + kvStateIds[keyGroupIndex] = new KvStateID(); + serverAddresses[keyGroupIndex] = new KvStateServerAddress(host, 1024 + keyGroupIndex); + + location.registerKvState(keyGroupIndex, kvStateIds[keyGroupIndex], serverAddresses[keyGroupIndex]); + assertEquals(keyGroupIndex + 1, location.getNumRegisteredKeyGroups()); + } + + // Lookup + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + assertEquals(kvStateIds[keyGroupIndex], location.getKvStateID(keyGroupIndex)); + assertEquals(serverAddresses[keyGroupIndex], location.getKvStateServerAddress(keyGroupIndex)); + } + + // Overwrite + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + kvStateIds[keyGroupIndex] = new KvStateID(); + serverAddresses[keyGroupIndex] = new KvStateServerAddress(host, 1024 + keyGroupIndex); + + location.registerKvState(keyGroupIndex, kvStateIds[keyGroupIndex], serverAddresses[keyGroupIndex]); + assertEquals(numKeyGroups, location.getNumRegisteredKeyGroups()); + } + + // Lookup + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + assertEquals(kvStateIds[keyGroupIndex], location.getKvStateID(keyGroupIndex)); + assertEquals(serverAddresses[keyGroupIndex], location.getKvStateServerAddress(keyGroupIndex)); + } + + // Unregister + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + location.unregisterKvState(keyGroupIndex); + assertEquals(numKeyGroups - keyGroupIndex - 1, location.getNumRegisteredKeyGroups()); + } + + // Lookup + for (int keyGroupIndex = 0; keyGroupIndex < numKeyGroups; keyGroupIndex++) { + assertEquals(null, location.getKvStateID(keyGroupIndex)); + assertEquals(null, location.getKvStateServerAddress(keyGroupIndex)); + } + + assertEquals(0, location.getNumRegisteredKeyGroups()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index 0c0d064b9725d..fc1a5df631563 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -43,6 +43,7 @@ import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.util.SerializedValue; import org.junit.Before; @@ -149,6 +150,8 @@ private static 
Task createTask() throws Exception { when(networkEnvironment.getPartitionManager()).thenReturn(partitionManager); when(networkEnvironment.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor( new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(), diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java index 60bf8e734bac8..ca7157a33b482 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java @@ -108,7 +108,8 @@ public void testComponentsStartupShutdown() { final NetworkEnvironment network = new NetworkEnvironment( TestingUtils.defaultExecutionContext(), timeout, - netConf); + netConf, + connectionInfo); final int numberOfSlots = 1; LeaderRetrievalService leaderRetrievalService = new StandaloneLeaderRetrievalService(jobManager.path().toString()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index fec9ef3599497..2f8e3dbaece65 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -19,9 +19,9 @@ package org.apache.flink.runtime.taskmanager; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.testutils.OneShotLatch; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -34,7 +34,6 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; @@ -45,12 +44,12 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.messages.TaskMessages; - +import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.util.SerializedValue; import org.junit.After; import org.junit.Before; import org.junit.Test; - import scala.concurrent.duration.FiniteDuration; import java.io.IOException; @@ -69,7 +68,6 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; - import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static 
org.mockito.Mockito.doThrow; @@ -603,6 +601,8 @@ private Task createTask(Class invokable, when(network.getPartitionManager()).thenReturn(partitionManager); when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); return createTask(invokable, libCache, network); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 5237c627eb722..f9698a81e203c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -134,11 +135,15 @@ private static TaskDeploymentDescriptor createTaskDeploymentDescriptor( } private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException { + NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); + when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); + return new Task( tdd, mock(MemoryManager.class), mock(IOManager.class), - mock(NetworkEnvironment.class), + networkEnvironment, mock(BroadcastVariableManager.class), mock(ActorGateway.class), mock(ActorGateway.class), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 70842088b61dc..05b8e8cabd15d 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -47,6 +47,8 @@ import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; @@ -91,6 +93,8 @@ public class StreamMockEnvironment implements Environment { private final AccumulatorRegistry accumulatorRegistry; + private final TaskKvStateRegistry kvStateRegistry; + private final int bufferSize; private final ExecutionConfig executionConfig; @@ -110,6 +114,9 @@ public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, this.executionConfig = executionConfig; this.accumulatorRegistry = new AccumulatorRegistry(jobID, getExecutionId()); + + KvStateRegistry registry = new KvStateRegistry(); + 
this.kvStateRegistry = registry.createTaskRegistry(jobID, getJobVertexId()); } public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize, @@ -293,6 +300,11 @@ public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; } + @Override + public TaskKvStateRegistry getTaskKvStateRegistry() { + return kvStateRegistry; + } + @Override public void acknowledgeCheckpoint(long checkpointId) { } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index e9d583cd16dbc..bcd8a5f57043b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -19,11 +19,9 @@ package org.apache.flink.streaming.runtime.tasks; import akka.actor.ActorRef; - import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.blob.BlobKey; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -41,6 +39,8 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -49,10 +49,8 @@ import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.ExceptionUtils; - import org.apache.flink.util.SerializedValue; import org.junit.Test; - import scala.concurrent.ExecutionContext; import scala.concurrent.Future; import scala.concurrent.duration.FiniteDuration; @@ -64,14 +62,13 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class StreamTaskTest { @@ -134,6 +131,8 @@ private Task createTask(Class invokable, StreamConf when(network.getPartitionManager()).thenReturn(partitionManager); when(network.getPartitionConsumableNotifier()).thenReturn(consumableNotifier); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor( new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(), From 5677f96ee49d633b7f7c6ade9c82e45e7eecc746 Mon Sep 17 00:00:00 2001 From: Ufuk Celebi Date: Mon, 30 May 2016 14:00:49 +0200 Subject: [PATCH 3/6] [FLINK-3779] [runtime] Add 
KvState network client and server - Adds a Netty-based server and client to query KvState instances, which have been published to the KvStateRegistry. --- .../io/network/netty/NettyBufferPool.java | 2 +- .../flink/runtime/jobgraph/JobVertexID.java | 11 +- .../apache/flink/runtime/query/KvStateID.java | 41 + .../runtime/query/KvStateServerAddress.java | 87 +++ .../netty/AtomicKvStateRequestStats.java | 94 +++ .../runtime/query/netty/ChunkedByteBuf.java | 97 +++ .../netty/DisabledKvStateRequestStats.java | 45 ++ .../runtime/query/netty/KvStateClient.java | 575 ++++++++++++++ .../query/netty/KvStateClientHandler.java | 104 +++ .../netty/KvStateClientHandlerCallback.java | 54 ++ .../query/netty/KvStateRequestStats.java | 53 ++ .../runtime/query/netty/KvStateServer.java | 243 ++++++ .../query/netty/KvStateServerHandler.java | 301 ++++++++ .../query/netty/UnknownKeyOrNamespace.java | 31 + .../runtime/query/netty/UnknownKvStateID.java | 35 + .../query/netty/message/KvStateRequest.java | 89 +++ .../netty/message/KvStateRequestFailure.java | 68 ++ .../netty/message/KvStateRequestResult.java | 74 ++ .../message/KvStateRequestSerializer.java | 518 +++++++++++++ .../netty/message/KvStateRequestType.java | 40 + .../runtime/query/netty/package-info.java | 80 ++ .../runtime/util/DataInputDeserializer.java | 8 + .../query/netty/KvStateClientHandlerTest.java | 110 +++ .../query/netty/KvStateClientTest.java | 718 ++++++++++++++++++ .../query/netty/KvStateServerHandlerTest.java | 622 +++++++++++++++ .../query/netty/KvStateServerTest.java | 174 +++++ .../message/KvStateRequestSerializerTest.java | 258 +++++++ 27 files changed, 4525 insertions(+), 7 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateID.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateServerAddress.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/AtomicKvStateRequestStats.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/ChunkedByteBuf.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/DisabledKvStateRequestStats.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandler.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerCallback.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateRequestStats.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServerHandler.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKeyOrNamespace.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKvStateID.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequest.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestFailure.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestResult.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java create 
mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java index 6d09f26c00956..4a88b34381e9e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyBufferPool.java @@ -57,7 +57,7 @@ public class NettyBufferPool implements ByteBufAllocator { * @param numberOfArenas Number of arenas (recommended: 2 * number of task * slots) */ - NettyBufferPool(int numberOfArenas) { + public NettyBufferPool(int numberOfArenas) { checkArgument(numberOfArenas >= 1, "Number of arenas"); this.numberOfArenas = numberOfArenas; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertexID.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertexID.java index 514aabc04743f..1f78b211e6648 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertexID.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertexID.java @@ -18,10 +18,10 @@ package org.apache.flink.runtime.jobgraph; -import javax.xml.bind.DatatypeConverter; - import org.apache.flink.util.AbstractID; +import javax.xml.bind.DatatypeConverter; + /** * A class for statistically unique job vertex IDs. */ @@ -32,15 +32,14 @@ public class JobVertexID extends AbstractID { public JobVertexID() { super(); } + public JobVertexID(byte[] bytes) { + super(bytes); + } public JobVertexID(long lowerPart, long upperPart) { super(lowerPart, upperPart); } - public JobVertexID(byte[] bytes) { - super(bytes); - } - public static JobVertexID fromHexString(String hexString) { return new JobVertexID(DatatypeConverter.parseHexBinary(hexString)); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateID.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateID.java new file mode 100644 index 0000000000000..bb05833be7bd8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateID.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.AbstractID; + +/** + * Identifier for {@link KvState} instances. + * + *
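+ * <p>Together with a {@link KvStateServerAddress}, such an ID pins down the exact
+ * state instance to query; a sketch with placeholder values (the port is arbitrary):
+ * <pre>{@code
+ * KvStateID kvStateId = new KvStateID();
+ * KvStateServerAddress server = new KvStateServerAddress(InetAddress.getLocalHost(), 9069);
+ * }</pre>
+ *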

Assigned when registering state at the {@link KvStateRegistry}. + */ +public class KvStateID extends AbstractID { + + private static final long serialVersionUID = 1L; + + public KvStateID() { + super(); + } + + public KvStateID(long lowerPart, long upperPart) { + super(lowerPart, upperPart); + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateServerAddress.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateServerAddress.java new file mode 100644 index 0000000000000..7887ed11c648d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateServerAddress.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.net.InetAddress; + +/** + * The (host, port)-address of a {@link KvStateServer}. + */ +public class KvStateServerAddress implements Serializable { + + private static final long serialVersionUID = 1L; + + /** KvStateServer host address. */ + private final InetAddress hostAddress; + + /** KvStateServer port. */ + private final int port; + + /** + * Creates a KvStateServerAddress for the given KvStateServer host address + * and port. + * + * @param hostAddress KvStateServer host address + * @param port KvStateServer port + */ + public KvStateServerAddress(InetAddress hostAddress, int port) { + this.hostAddress = Preconditions.checkNotNull(hostAddress, "Host address"); + Preconditions.checkArgument(port > 0 && port <= 65535, "Port " + port + " is out of range 1-65535"); + this.port = port; + } + + /** + * Returns the host address of the KvStateServer. + * + * @return KvStateServer host address + */ + public InetAddress getHost() { + return hostAddress; + } + + /** + * Returns the port of the KvStateServer. 
+ * + * @return KvStateServer port + */ + public int getPort() { + return port; + } + + @Override + public boolean equals(Object o) { + if (this == o) { return true; } + if (o == null || getClass() != o.getClass()) { return false; } + + KvStateServerAddress that = (KvStateServerAddress) o; + + return port == that.port && hostAddress.equals(that.hostAddress); + } + + @Override + public int hashCode() { + int result = hostAddress.hashCode(); + result = 31 * result + port; + return result; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/AtomicKvStateRequestStats.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/AtomicKvStateRequestStats.java new file mode 100644 index 0000000000000..2fca4a8790566 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/AtomicKvStateRequestStats.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * Atomic {@link KvStateRequestStats} implementation. + */ +public class AtomicKvStateRequestStats implements KvStateRequestStats { + + /** + * Number of active connections. + */ + private final AtomicLong numConnections = new AtomicLong(); + + /** + * Total number of reported requests. + */ + private final AtomicLong numRequests = new AtomicLong(); + + /** + * Total number of successful requests (<= reported requests). + */ + private final AtomicLong numSuccessful = new AtomicLong(); + + /** + * Total duration of all successful requests. + */ + private final AtomicLong successfulDuration = new AtomicLong(); + + /** + * Total number of failed requests (<= reported requests). 
+ */ + private final AtomicLong numFailed = new AtomicLong(); + + @Override + public void reportActiveConnection() { + numConnections.incrementAndGet(); + } + + @Override + public void reportInactiveConnection() { + numConnections.decrementAndGet(); + } + + @Override + public void reportRequest() { + numRequests.incrementAndGet(); + } + + @Override + public void reportSuccessfulRequest(long durationTotalMillis) { + numSuccessful.incrementAndGet(); + successfulDuration.addAndGet(durationTotalMillis); + } + + @Override + public void reportFailedRequest() { + numFailed.incrementAndGet(); + } + + public long getNumConnections() { + return numConnections.get(); + } + + public long getNumRequests() { + return numRequests.get(); + } + + public long getNumSuccessful() { + return numSuccessful.get(); + } + + public long getNumFailed() { + return numFailed.get(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/ChunkedByteBuf.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/ChunkedByteBuf.java new file mode 100644 index 0000000000000..6d32489cf4078 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/ChunkedByteBuf.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.stream.ChunkedInput; +import io.netty.handler.stream.ChunkedWriteHandler; +import org.apache.flink.util.Preconditions; + +/** + * A {@link ByteBuf} instance to be consumed in chunks by {@link ChunkedWriteHandler}, + * respecting the high and low watermarks. + * + * @see Low/High Watermarks + */ +class ChunkedByteBuf implements ChunkedInput { + + /** The buffer to chunk */ + private final ByteBuf buf; + + /** Size of chunks */ + private final int chunkSize; + + /** Closed flag */ + private boolean isClosed; + + /** End of input flag */ + private boolean isEndOfInput; + + public ChunkedByteBuf(ByteBuf buf, int chunkSize) { + this.buf = Preconditions.checkNotNull(buf, "Buffer"); + Preconditions.checkArgument(chunkSize > 0, "Non-positive chunk size"); + this.chunkSize = chunkSize; + } + + @Override + public boolean isEndOfInput() throws Exception { + return isClosed || isEndOfInput; + } + + @Override + public void close() throws Exception { + if (!isClosed) { + // If we did not consume the whole buffer yet, we have to release + // it here. Otherwise, it's the responsibility of the consumer. 
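+ // (Chunks handed out by readChunk() share this buffer's reference count,
+ // so after the final chunk has been read it is the receiver of that chunk,
+ // not close(), that releases the buffer.)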
+ if (!isEndOfInput) { + buf.release(); + } + + isClosed = true; + } + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + if (isClosed) { + return null; + } else if (buf.readableBytes() <= chunkSize) { + isEndOfInput = true; + + // Don't retain as the consumer is responsible to release it + return buf.slice(); + } else { + // Return a chunk sized slice of the buffer. The ref count is + // shared with the original buffer. That's why we need to retain + // a reference here. + return buf.readSlice(chunkSize).retain(); + } + } + + @Override + public String toString() { + return "ChunkedByteBuf{" + + "buf=" + buf + + ", chunkSize=" + chunkSize + + ", isClosed=" + isClosed + + ", isEndOfInput=" + isEndOfInput + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/DisabledKvStateRequestStats.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/DisabledKvStateRequestStats.java new file mode 100644 index 0000000000000..de8824ddb3a78 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/DisabledKvStateRequestStats.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +/** + * Disabled {@link KvStateRequestStats} implementation. + */ +public class DisabledKvStateRequestStats implements KvStateRequestStats { + + @Override + public void reportActiveConnection() { + } + + @Override + public void reportInactiveConnection() { + } + + @Override + public void reportRequest() { + } + + @Override + public void reportSuccessfulRequest(long durationTotalMillis) { + } + + @Override + public void reportFailedRequest() { + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java new file mode 100644 index 0000000000000..6cfe86b304d08 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClient.java @@ -0,0 +1,575 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import akka.dispatch.Futures; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.stream.ChunkedWriteHandler; +import org.apache.flink.runtime.io.network.netty.NettyBufferPool; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.util.Preconditions; +import scala.concurrent.Future; +import scala.concurrent.Promise; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Netty-based client querying {@link KvStateServer} instances. + * + *
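+ * <p>Statistics about issued requests are reported to the {@link KvStateRequestStats}
+ * instance given to the constructor; a sketch of reading them back when the atomic
+ * implementation defined above is used:
+ * <pre>{@code
+ * AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+ * KvStateClient client = new KvStateClient(1, stats);
+ * // ... issue some queries ...
+ * long successful = stats.getNumSuccessful();
+ * long failed = stats.getNumFailed();
+ * }</pre>
+ *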

This client can be used by multiple threads concurrently. Operations are + * executed asynchronously and return Futures to their result. + * + *
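+ * <p>A minimal (hypothetical) round trip looks as follows; serverAddress, kvStateId
+ * and serializedKeyAndNamespace are placeholders that a real caller obtains from a
+ * KvStateLocation lookup and from the state's key/namespace serializers:
+ * <pre>{@code
+ * KvStateClient client = new KvStateClient(1, new DisabledKvStateRequestStats());
+ *
+ * Future<byte[]> future = client.getKvState(serverAddress, kvStateId, serializedKeyAndNamespace);
+ *
+ * // Block until the serialized value arrives (the timeout is up to the caller).
+ * byte[] serializedValue = Await.result(future, new FiniteDuration(10, TimeUnit.SECONDS));
+ * }</pre>
+ *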

The incoming pipeline looks as follows: + *
+ * <pre>
+ * Socket.read() -> LengthFieldBasedFrameDecoder -> KvStateClientHandler
+ * </pre>
+ * + *

Received binary messages are expected to contain a frame length field. Netty's + * {@link LengthFieldBasedFrameDecoder} is used to fully receive the frame before + * giving it to our {@link KvStateClientHandler}. + * + *
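+ * <p>Concretely, every frame is expected to start with a 4-byte length field that the
+ * decoder strips before handing the payload on; the call used when bootstrapping the
+ * channel (see the constructor below), with its parameters spelled out:
+ * <pre>{@code
+ * new LengthFieldBasedFrameDecoder(
+ *         Integer.MAX_VALUE, // maxFrameLength: effectively unbounded
+ *         0,                 // lengthFieldOffset
+ *         4,                 // lengthFieldLength: 4 byte frame length
+ *         0,                 // lengthAdjustment
+ *         4);                // initialBytesToStrip: drop the length field itself
+ * }</pre>
+ *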

Connections are established and closed by the client. The server only + * closes the connection on a fatal failure that cannot be recovered. + */ +public class KvStateClient { + + /** Netty's Bootstrap. */ + private final Bootstrap bootstrap; + + /** Statistics tracker */ + private final KvStateRequestStats stats; + + /** Established connections. */ + private final ConcurrentHashMap establishedConnections = + new ConcurrentHashMap<>(); + + /** Pending connections. */ + private final ConcurrentHashMap pendingConnections = + new ConcurrentHashMap<>(); + + /** Atomic shut down flag. */ + private final AtomicBoolean shutDown = new AtomicBoolean(); + + /** + * Creates a client with the specified number of event loop threads. + * + * @param numEventLoopThreads Number of event loop threads (minimum 1). + */ + public KvStateClient(int numEventLoopThreads, KvStateRequestStats stats) { + Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads."); + NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads); + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Flink KvStateClient Event Loop Thread %d") + .build(); + + NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory); + + this.bootstrap = new Bootstrap() + .group(nioGroup) + .channel(NioSocketChannel.class) + .option(ChannelOption.ALLOCATOR, bufferPool) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + // ChunkedWriteHandler respects Channel writability + .addLast(new ChunkedWriteHandler()); + } + }); + + this.stats = Preconditions.checkNotNull(stats, "Statistics tracker"); + } + + /** + * Returns a future holding the serialized request result. + * + *

If the server does not serve a KvState instance with the given ID, + * the Future will be failed with an {@link UnknownKvStateID}. + * + *

If the KvState instance does not hold any data for the given key + * and namespace, the Future will be failed with an {@link UnknownKeyOrNamespace}. + * + *

All other failures are forwarded to the Future. + * + * @param serverAddress Address of the server to query + * @param kvStateId ID of the KvState instance to query + * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance with + * @return Future holding the serialized result + */ + public Future getKvState( + KvStateServerAddress serverAddress, + KvStateID kvStateId, + byte[] serializedKeyAndNamespace) { + + if (shutDown.get()) { + return Futures.failed(new IllegalStateException("Shut down")); + } + + EstablishedConnection connection = establishedConnections.get(serverAddress); + + if (connection != null) { + return connection.getKvState(kvStateId, serializedKeyAndNamespace); + } else { + PendingConnection pendingConnection = pendingConnections.get(serverAddress); + if (pendingConnection != null) { + // There was a race, use the existing pending connection. + return pendingConnection.getKvState(kvStateId, serializedKeyAndNamespace); + } else { + // We try to connect to the server. + PendingConnection pending = new PendingConnection(serverAddress); + PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending); + + if (previous == null) { + // OK, we are responsible to connect. + bootstrap.connect(serverAddress.getHost(), serverAddress.getPort()) + .addListener(pending); + + return pending.getKvState(kvStateId, serializedKeyAndNamespace); + } else { + // There was a race, use the existing pending connection. + return previous.getKvState(kvStateId, serializedKeyAndNamespace); + } + } + } + } + + /** + * Shuts down the client and closes all connections. + * + *
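For orientation, here is a caller-side sketch of the life cycle described above (illustrative only, not a hunk of this patch). It assumes the returned Future carries the serialized value as a byte[], wires in a no-op KvStateRequestStats, and uses placeholder values (class name, port, KvStateID, empty key/namespace bytes) purely for illustration.

import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.KvStateClient;
import org.apache.flink.runtime.query.netty.KvStateRequestStats;

import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;

public class KvStateClientUsageSketch {

    public static void main(String[] args) throws Exception {
        // No-op implementation of the statistics callback interface.
        KvStateRequestStats stats = new KvStateRequestStats() {
            @Override public void reportActiveConnection() {}
            @Override public void reportInactiveConnection() {}
            @Override public void reportRequest() {}
            @Override public void reportSuccessfulRequest(long durationTotalMillis) {}
            @Override public void reportFailedRequest() {}
        };

        KvStateClient client = new KvStateClient(1, stats);
        try {
            // Placeholder address, ID and key/namespace bytes; in practice these
            // come from the KvStateRegistry and a location lookup.
            KvStateServerAddress address =
                    new KvStateServerAddress(InetAddress.getLocalHost(), 9069);
            KvStateID kvStateId = new KvStateID(0L, 0L);
            byte[] serializedKeyAndNamespace = new byte[0];

            Future<byte[]> future =
                    client.getKvState(address, kvStateId, serializedKeyAndNamespace);

            // Awaiting fails with UnknownKvStateID or UnknownKeyOrNamespace as described above.
            byte[] serializedValue =
                    Await.result(future, new FiniteDuration(10, TimeUnit.SECONDS));
            System.out.println("Received " + serializedValue.length + " bytes");
        } finally {
            client.shutDown();
        }
    }
}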

After a call to this method, all returned futures will be failed. + */ + public void shutDown() { + if (shutDown.compareAndSet(false, true)) { + for (Map.Entry conn : establishedConnections.entrySet()) { + if (establishedConnections.remove(conn.getKey(), conn.getValue())) { + conn.getValue().close(); + } + } + + for (Map.Entry conn : pendingConnections.entrySet()) { + if (pendingConnections.remove(conn.getKey()) != null) { + conn.getValue().close(); + } + } + + if (bootstrap != null) { + EventLoopGroup group = bootstrap.group(); + if (group != null) { + group.shutdownGracefully(); + } + } + } + } + + /** + * Closes the connection to the given server address if it exists. + * + *

If there is a request to the server a new connection will be established. + * + * @param serverAddress Target address of the connection to close + */ + public void closeConnection(KvStateServerAddress serverAddress) { + PendingConnection pending = pendingConnections.get(serverAddress); + if (pending != null) { + pending.close(); + } + + EstablishedConnection established = establishedConnections.remove(serverAddress); + if (established != null) { + established.close(); + } + } + + /** + * A pending connection that is in the process of connecting. + */ + private class PendingConnection implements ChannelFutureListener { + + /** Lock to guard the connect call, channel hand in, etc. */ + private final Object connectLock = new Object(); + + /** Address of the server we are connecting to. */ + private final KvStateServerAddress serverAddress; + + /** Queue of requests while connecting. */ + private final ArrayDeque queuedRequests = new ArrayDeque<>(); + + /** The established connection after the connect succeeds. */ + private EstablishedConnection established; + + /** Closed flag. */ + private boolean closed; + + /** Failure cause if something goes wrong. */ + private Throwable failureCause; + + /** + * Creates a pending connection to the given server. + * + * @param serverAddress Address of the server to connect to. + */ + private PendingConnection(KvStateServerAddress serverAddress) { + this.serverAddress = serverAddress; + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Callback from the Bootstrap's connect call. + if (future.isSuccess()) { + handInChannel(future.channel()); + } else { + close(future.cause()); + } + } + + /** + * Returns a future holding the serialized request result. + * + *

If the channel has been established, forward the call to the + * established channel, otherwise queue it for when the channel is + * handed in. + * + * @param kvStateId ID of the KvState instance to query + * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance + * with + * @return Future holding the serialized result + */ + public Future getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { + synchronized (connectLock) { + if (failureCause != null) { + return Futures.failed(failureCause); + } else if (closed) { + return Futures.failed(new ClosedChannelException()); + } else { + if (established != null) { + return established.getKvState(kvStateId, serializedKeyAndNamespace); + } else { + // Queue this and handle when connected + PendingRequest pending = new PendingRequest(kvStateId, serializedKeyAndNamespace); + queuedRequests.add(pending); + return pending.promise.future(); + } + } + } + } + + /** + * Hands in a channel after a successful connection. + * + * @param channel Channel to hand in + */ + private void handInChannel(Channel channel) { + synchronized (connectLock) { + if (closed || failureCause != null) { + // Close the channel and we are done. Any queued requests + // are removed on the close/failure call and after that no + // new ones can be enqueued. + channel.close(); + } else { + established = new EstablishedConnection(serverAddress, channel); + + PendingRequest pending; + while ((pending = queuedRequests.poll()) != null) { + Future resultFuture = established.getKvState( + pending.kvStateId, + pending.serializedKeyAndNamespace); + + pending.promise.completeWith(resultFuture); + } + + // Publish the channel for the general public + establishedConnections.put(serverAddress, established); + pendingConnections.remove(serverAddress); + + // Check shut down for possible race with shut down. We + // don't want any lingering connections after shut down, + // which can happen if we don't check this here. + if (shutDown.get()) { + if (establishedConnections.remove(serverAddress, established)) { + established.close(); + } + } + } + } + } + + /** + * Close the connecting channel with a ClosedChannelException. + */ + private void close() { + close(new ClosedChannelException()); + } + + /** + * Close the connecting channel with an Exception (can be + * null) or forward to the established channel. + */ + private void close(Throwable cause) { + synchronized (connectLock) { + if (!closed) { + if (failureCause == null) { + failureCause = cause; + } + + if (established != null) { + established.close(); + } else { + PendingRequest pending; + while ((pending = queuedRequests.poll()) != null) { + pending.promise.tryFailure(cause); + } + } + + closed = true; + } + } + } + + /** + * A pending request queued while the channel is connecting. 
+ */ + private final class PendingRequest { + + private final KvStateID kvStateId; + private final byte[] serializedKeyAndNamespace; + private final Promise promise; + + private PendingRequest(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { + this.kvStateId = kvStateId; + this.serializedKeyAndNamespace = serializedKeyAndNamespace; + this.promise = Futures.promise(); + } + } + + @Override + public String toString() { + synchronized (connectLock) { + return "PendingConnection{" + + "serverAddress=" + serverAddress + + ", queuedRequests=" + queuedRequests.size() + + ", established=" + (established != null) + + ", closed=" + closed + + '}'; + } + } + } + + /** + * An established connection that wraps the actual channel instance and is + * registered at the {@link KvStateClientHandler} for callbacks. + */ + private class EstablishedConnection implements KvStateClientHandlerCallback { + + /** Address of the server we are connected to. */ + private final KvStateServerAddress serverAddress; + + /** The actual TCP channel. */ + private final Channel channel; + + /** Pending requests keyed by request ID. */ + private final ConcurrentHashMap pendingRequests = new ConcurrentHashMap<>(); + + /** Current request number used to assign unique request IDs. */ + private final AtomicLong requestCount = new AtomicLong(); + + /** Reference to a failure that was reported by the channel. */ + private final AtomicReference failureCause = new AtomicReference<>(); + + /** + * Creates an established connection with the given channel. + * + * @param serverAddress Address of the server connected to + * @param channel The actual TCP channel + */ + EstablishedConnection(KvStateServerAddress serverAddress, Channel channel) { + this.serverAddress = Preconditions.checkNotNull(serverAddress, "KvStateServerAddress"); + this.channel = Preconditions.checkNotNull(channel, "Channel"); + + // Add the client handler with the callback + channel.pipeline().addLast("KvStateClientHandler", new KvStateClientHandler(this)); + + stats.reportActiveConnection(); + } + + /** + * Close the channel with a ClosedChannelException. + */ + void close() { + close(new ClosedChannelException()); + } + + /** + * Close the channel with a cause. + * + * @param cause The cause to close the channel with. + * @return Channel close future + */ + private boolean close(Throwable cause) { + if (failureCause.compareAndSet(null, cause)) { + channel.close(); + stats.reportInactiveConnection(); + + for (long requestId : pendingRequests.keySet()) { + PromiseAndTimestamp pending = pendingRequests.remove(requestId); + if (pending != null && pending.promise.tryFailure(cause)) { + stats.reportFailedRequest(); + } + } + + return true; + } + + return false; + } + + /** + * Returns a future holding the serialized request result. 
+ * + * @param kvStateId ID of the KvState instance to query + * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance + * with + * @return Future holding the serialized result + */ + Future getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { + PromiseAndTimestamp requestPromiseTs = new PromiseAndTimestamp( + Futures.promise(), + System.nanoTime()); + + try { + final long requestId = requestCount.getAndIncrement(); + pendingRequests.put(requestId, requestPromiseTs); + + stats.reportRequest(); + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + kvStateId, + serializedKeyAndNamespace); + + channel.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + // Fail promise if not failed to write + PromiseAndTimestamp pending = pendingRequests.remove(requestId); + if (pending != null && pending.promise.tryFailure(future.cause())) { + stats.reportFailedRequest(); + } + } + } + }); + + // Check failure for possible race. We don't want any lingering + // promises after a failure, which can happen if we don't check + // this here. Note that close is treated as a failure as well. + Throwable failure = failureCause.get(); + if (failure != null) { + // Remove from pending requests to guard against concurrent + // removal and to make sure that we only count it once as failed. + PromiseAndTimestamp p = pendingRequests.remove(requestId); + if (p != null && p.promise.tryFailure(failure)) { + stats.reportFailedRequest(); + } + } + } catch (Throwable t) { + requestPromiseTs.promise.tryFailure(t); + } + + return requestPromiseTs.promise.future(); + } + + @Override + public void onRequestResult(long requestId, byte[] serializedValue) { + PromiseAndTimestamp pending = pendingRequests.remove(requestId); + if (pending != null && pending.promise.trySuccess(serializedValue)) { + long durationMillis = (System.nanoTime() - pending.timestamp) / 1_000_000; + stats.reportSuccessfulRequest(durationMillis); + } + } + + @Override + public void onRequestFailure(long requestId, Throwable cause) { + PromiseAndTimestamp pending = pendingRequests.remove(requestId); + if (pending != null && pending.promise.tryFailure(cause)) { + stats.reportFailedRequest(); + } + } + + @Override + public void onFailure(Throwable cause) { + if (close(cause)) { + // Remove from established channels, otherwise future + // requests will be handled by this failed channel. + establishedConnections.remove(serverAddress, this); + } + } + + @Override + public String toString() { + return "EstablishedConnection{" + + "serverAddress=" + serverAddress + + ", channel=" + channel + + ", pendingRequests=" + pendingRequests.size() + + ", requestCount=" + requestCount + + ", failureCause=" + failureCause + + '}'; + } + + /** + * Pair of promise and a timestamp. 
+ */ + private class PromiseAndTimestamp { + + private final Promise promise; + private final long timestamp; + + public PromiseAndTimestamp(Promise promise, long timestamp) { + this.promise = promise; + this.timestamp = timestamp; + } + } + + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandler.java new file mode 100644 index 0000000000000..2166bf257ab0e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandler.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; +import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure; +import org.apache.flink.runtime.query.netty.message.KvStateRequestResult; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.channels.ClosedChannelException; + +/** + * This handler expects responses from {@link KvStateServerHandler}. + * + *

It deserializes the response and calls the registered callback, which is + * responsible for actually handling the result (see {@link KvStateClient.EstablishedConnection}). + */ +class KvStateClientHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateClientHandler.class); + + private final KvStateClientHandlerCallback callback; + + /** + * Creates a {@link KvStateClientHandler} with the callback. + * + * @param callback Callback for responses. + */ + KvStateClientHandler(KvStateClientHandlerCallback callback) { + this.callback = callback; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + try { + ByteBuf buf = (ByteBuf) msg; + KvStateRequestType msgType = KvStateRequestSerializer.deserializeHeader(buf); + + if (msgType == KvStateRequestType.REQUEST_RESULT) { + KvStateRequestResult result = KvStateRequestSerializer.deserializeKvStateRequestResult(buf); + callback.onRequestResult(result.getRequestId(), result.getSerializedResult()); + } else if (msgType == KvStateRequestType.REQUEST_FAILURE) { + KvStateRequestFailure failure = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + callback.onRequestFailure(failure.getRequestId(), failure.getCause()); + } else if (msgType == KvStateRequestType.SERVER_FAILURE) { + throw KvStateRequestSerializer.deserializeServerFailure(buf); + } else { + throw new IllegalStateException("Unexpected response type '" + msgType + "'"); + } + } catch (Throwable t1) { + try { + callback.onFailure(t1); + } catch (Throwable t2) { + LOG.error("Failed to notify callback about failure", t2); + } + } finally { + ReferenceCountUtil.release(msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + try { + callback.onFailure(cause); + } catch (Throwable t) { + LOG.error("Failed to notify callback about failure", t); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Only the client is expected to close the channel. Otherwise it + // indicates a failure. Note that this will be invoked in both cases + // though. If the callback closed the channel, the callback must be + // ignored. + try { + callback.onFailure(new ClosedChannelException()); + } catch (Throwable t) { + LOG.error("Failed to notify callback about failure", t); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerCallback.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerCallback.java new file mode 100644 index 0000000000000..65ff78180b06e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerCallback.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import org.apache.flink.runtime.query.netty.message.KvStateRequest; + +/** + * Callback for {@link KvStateClientHandler}. + */ +interface KvStateClientHandlerCallback { + + /** + * Called on a successful {@link KvStateRequest}. + * + * @param requestId ID of the request + * @param serializedValue Serialized value for the request + */ + void onRequestResult(long requestId, byte[] serializedValue); + + /** + * Called on a failed {@link KvStateRequest}. + * + * @param requestId ID of the request + * @param cause Cause of the request failure + */ + void onRequestFailure(long requestId, Throwable cause); + + /** + * Called on any failure, which is not related to a specific request. + * + *

This can be for example a caught Exception in the channel pipeline + * or an unexpected channel close. + * + * @param cause Cause of the failure + */ + void onFailure(Throwable cause); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateRequestStats.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateRequestStats.java new file mode 100644 index 0000000000000..1c0d8d51218b2 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateRequestStats.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +/** + * Simple statistics for {@link KvStateServer} monitoring. + */ +public interface KvStateRequestStats { + + /** + * Reports an active connection. + */ + void reportActiveConnection(); + + /** + * Reports an inactive connection. + */ + void reportInactiveConnection(); + + /** + * Reports an incoming request. + */ + void reportRequest(); + + /** + * Reports a successfully handled request. + * + * @param durationTotalMillis Duration of the request (in milliseconds). + */ + void reportSuccessfulRequest(long durationTotalMillis); + + /** + * Reports a failure during a request. + */ + void reportFailedRequest(); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java new file mode 100644 index 0000000000000..0c0c5ecc6528e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServer.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.query.netty; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.stream.ChunkedWriteHandler; +import io.netty.util.concurrent.Future; +import org.apache.flink.runtime.io.network.netty.NettyBufferPool; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.netty.message.KvStateRequest; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +/** + * Netty-based server answering {@link KvStateRequest} messages. + * + *

Requests are handled by asynchronous query tasks (see {@link KvStateServerHandler.AsyncKvStateQueryTask}) + * that are executed by a separate query thread pool. This pool is shared among + * all TCP connections. + * + *

The incoming pipeline looks as follows: + *

+ * Socket.read() -> LengthFieldBasedFrameDecoder -> KvStateServerHandler
+ * 
+ * + *

Received binary messages are expected to contain a frame length field. Netty's + * {@link LengthFieldBasedFrameDecoder} is used to fully receive the frame before + * giving it to our {@link KvStateServerHandler}. + * + *

Connections are established and closed by the client. The server only + * closes the connection on a fatal failure that cannot be recovered. A + * server-side connection close is considered a failure by the client. + */ +public class KvStateServer { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateServer.class); + + /** Server config: low water mark */ + private static final int LOW_WATER_MARK = 8 * 1024; + + /** Server config: high water mark */ + private static final int HIGH_WATER_MARK = 32 * 1024; + + /** Netty's ServerBootstrap. */ + private final ServerBootstrap bootstrap; + + /** Query executor thread pool. */ + private final ExecutorService queryExecutor; + + /** Address of this server. */ + private KvStateServerAddress serverAddress; + + /** + * Creates the {@link KvStateServer}. + * + *

The server needs to be started via {@link #start()} in order to bind + * to the configured bind address. + * + * @param bindAddress Address to bind to + * @param bindPort Port to bind to. Pick random port if 0. + * @param numEventLoopThreads Number of event loop threads + * @param numQueryThreads Number of query threads + * @param kvStateRegistry KvStateRegistry to query for KvState instances + * @param stats Statistics tracker + */ + public KvStateServer( + InetAddress bindAddress, + int bindPort, + int numEventLoopThreads, + int numQueryThreads, + KvStateRegistry kvStateRegistry, + KvStateRequestStats stats) { + + Preconditions.checkArgument(bindPort >= 0 && bindPort <= 65536, "Port " + bindPort + + " is out of valid port range (0-65536)."); + + Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads."); + Preconditions.checkArgument(numQueryThreads >= 1, "Non-positive number of query threads."); + + Preconditions.checkNotNull(kvStateRegistry, "KvStateRegistry"); + Preconditions.checkNotNull(stats, "KvStateRequestStats"); + + NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads); + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Flink KvStateServer EventLoop Thread %d") + .build(); + + NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory); + + queryExecutor = createQueryExecutor(numQueryThreads); + + // Shared between all channels + KvStateServerHandler serverHandler = new KvStateServerHandler( + kvStateRegistry, + queryExecutor, + stats); + + bootstrap = new ServerBootstrap() + // Bind address and port + .localAddress(bindAddress, bindPort) + // NIO server channels + .group(nioGroup) + .channel(NioServerSocketChannel.class) + // Server channel Options + .option(ChannelOption.ALLOCATOR, bufferPool) + // Child channel options + .childOption(ChannelOption.ALLOCATOR, bufferPool) + .childOption(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, LOW_WATER_MARK) + .childOption(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, HIGH_WATER_MARK) + // See initializer for pipeline details + .childHandler(new KvStateServerChannelInitializer(serverHandler)); + } + + /** + * Starts the server by binding to the configured bind address (blocking). + * + * @throws InterruptedException If interrupted during the bind operation + */ + public void start() throws InterruptedException { + Channel channel = bootstrap.bind().sync().channel(); + + InetSocketAddress localAddress = (InetSocketAddress) channel.localAddress(); + serverAddress = new KvStateServerAddress(localAddress.getAddress(), localAddress.getPort()); + } + + /** + * Returns the address of this server. + * + * @return Server address + * @throws IllegalStateException If server has not been started yet + */ + public KvStateServerAddress getAddress() { + if (serverAddress == null) { + throw new IllegalStateException("KvStateServer not started yet."); + } + + return serverAddress; + } + + /** + * Shuts down the server and all related thread pools. + */ + public void shutDown() { + if (bootstrap != null) { + EventLoopGroup group = bootstrap.group(); + if (group != null) { + Future shutDownFuture = group.shutdownGracefully(0, 10, TimeUnit.SECONDS); + try { + shutDownFuture.await(); + } catch (InterruptedException e) { + LOG.error("Interrupted during shut down", e); + } + } + } + + if (queryExecutor != null) { + queryExecutor.shutdown(); + } + + serverAddress = null; + } + + /** + * Creates a thread pool for the query execution. 
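To make the start/stop contract above concrete, a minimal server-side sketch follows (illustrative only, not a hunk of this patch). It assumes a KvStateRegistry and a KvStateRequestStats instance are available from elsewhere; the loopback bind address, port 0 (random port) and the single event loop/query thread are arbitrary example values, as is the class name.

import java.net.InetAddress;

import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.KvStateRequestStats;
import org.apache.flink.runtime.query.netty.KvStateServer;

public class KvStateServerUsageSketch {

    public static KvStateServerAddress serve(
            KvStateRegistry registry,
            KvStateRequestStats stats) throws InterruptedException {

        // Loopback address, port 0 (pick a random port), 1 event loop thread, 1 query thread.
        KvStateServer server = new KvStateServer(
                InetAddress.getLoopbackAddress(), 0, 1, 1, registry, stats);

        server.start();                                      // blocking bind
        KvStateServerAddress address = server.getAddress();  // includes the picked port

        // ... hand the address to clients and serve queries; eventually:
        // server.shutDown();

        return address;
    }
}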
+ * + * @param numQueryThreads Number of query threads. + * @return Thread pool for query execution + */ + private static ExecutorService createQueryExecutor(int numQueryThreads) { + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Flink KvStateServer Query Thread %d") + .build(); + + return Executors.newFixedThreadPool(numQueryThreads, threadFactory); + } + + /** + * Channel pipeline initializer. + * + *

The request handler is shared, whereas the other handlers are created + * per channel. + */ + private static final class KvStateServerChannelInitializer extends ChannelInitializer { + + /** The shared request handler. */ + private final KvStateServerHandler sharedRequestHandler; + + /** + * Creates the channel pipeline initializer with the shared request handler. + * + * @param sharedRequestHandler Shared request handler. + */ + public KvStateServerChannelInitializer(KvStateServerHandler sharedRequestHandler) { + this.sharedRequestHandler = Preconditions.checkNotNull(sharedRequestHandler, "Request handler"); + } + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast(new ChunkedWriteHandler()) + .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast(sharedRequestHandler); + } + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServerHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServerHandler.java new file mode 100644 index 0000000000000..47f2ad6205ce2 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/KvStateServerHandler.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.netty.message.KvStateRequest; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.ExecutorService; + +/** + * This handler dispatches asynchronous tasks, which query {@link KvState} + * instances and write the result to the channel. + * + *

The network threads receive the message, deserialize it and dispatch the + * query task. The actual query is handled in a separate thread as it might + * otherwise block the network threads (file I/O etc.). + */ +@ChannelHandler.Sharable +class KvStateServerHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateServerHandler.class); + + /** KvState registry holding references to the KvState instances. */ + private final KvStateRegistry registry; + + /** Thread pool for query execution. */ + private final ExecutorService queryExecutor; + + /** Exposed server statistics. */ + private final KvStateRequestStats stats; + + /** + * Create the handler. + * + * @param kvStateRegistry Registry to query. + * @param queryExecutor Thread pool for query execution. + * @param stats Exposed server statistics. + */ + KvStateServerHandler( + KvStateRegistry kvStateRegistry, + ExecutorService queryExecutor, + KvStateRequestStats stats) { + + this.registry = Objects.requireNonNull(kvStateRegistry, "KvStateRegistry"); + this.queryExecutor = Objects.requireNonNull(queryExecutor, "Query thread pool"); + this.stats = Objects.requireNonNull(stats, "KvStateRequestStats"); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + stats.reportActiveConnection(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + stats.reportInactiveConnection(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + KvStateRequest request = null; + + try { + ByteBuf buf = (ByteBuf) msg; + KvStateRequestType msgType = KvStateRequestSerializer.deserializeHeader(buf); + + if (msgType == KvStateRequestType.REQUEST) { + // ------------------------------------------------------------ + // Request + // ------------------------------------------------------------ + request = KvStateRequestSerializer.deserializeKvStateRequest(buf); + + stats.reportRequest(); + + KvState kvState = registry.getKvState(request.getKvStateId()); + + if (kvState != null) { + // Execute actual query async, because it is possibly + // blocking (e.g. file I/O). + // + // A submission failure is not treated as fatal. + queryExecutor.submit(new AsyncKvStateQueryTask(ctx, request, kvState, stats)); + } else { + ByteBuf unknown = KvStateRequestSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new UnknownKvStateID(request.getKvStateId())); + + ctx.writeAndFlush(unknown); + + stats.reportFailedRequest(); + } + } else { + // ------------------------------------------------------------ + // Unexpected + // ------------------------------------------------------------ + ByteBuf failure = KvStateRequestSerializer.serializeServerFailure( + ctx.alloc(), + new IllegalArgumentException("Unexpected message type " + msgType + + ". KvStateServerHandler expects " + + KvStateRequestType.REQUEST + " messages.")); + + ctx.writeAndFlush(failure); + } + } catch (Throwable t) { + String stringifiedCause = ExceptionUtils.stringifyException(t); + + ByteBuf err; + if (request != null) { + String errMsg = "Failed to handle incoming request with ID " + + request.getRequestId() + ". Caused by: " + stringifiedCause; + err = KvStateRequestSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new RuntimeException(errMsg)); + + stats.reportFailedRequest(); + } else { + String errMsg = "Failed to handle incoming message. 
Caused by: " + stringifiedCause; + err = KvStateRequestSerializer.serializeServerFailure( + ctx.alloc(), + new RuntimeException(errMsg)); + } + + ctx.writeAndFlush(err); + } finally { + // IMPORTANT: We have to always recycle the incoming buffer. + // Otherwise we will leak memory out of Netty's buffer pool. + // + // If any operation ever holds on to the buffer, it is the + // responsibility of that operation to retain the buffer and + // release it later. + ReferenceCountUtil.release(msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + String stringifiedCause = ExceptionUtils.stringifyException(cause); + String msg = "Exception in server pipeline. Caused by: " + stringifiedCause; + + ByteBuf err = KvStateRequestSerializer.serializeServerFailure( + ctx.alloc(), + new RuntimeException(msg)); + + ctx.writeAndFlush(err).addListener(ChannelFutureListener.CLOSE); + } + + /** + * Task to execute the actual query against the {@link KvState} instance. + */ + private static class AsyncKvStateQueryTask implements Runnable { + + private final ChannelHandlerContext ctx; + + private final KvStateRequest request; + + private final KvState kvState; + + private final KvStateRequestStats stats; + + private final long creationNanos; + + public AsyncKvStateQueryTask( + ChannelHandlerContext ctx, + KvStateRequest request, + KvState kvState, + KvStateRequestStats stats) { + + this.ctx = Objects.requireNonNull(ctx, "Channel handler context"); + this.request = Objects.requireNonNull(request, "State query"); + this.kvState = Objects.requireNonNull(kvState, "KvState"); + this.stats = Objects.requireNonNull(stats, "State query stats"); + this.creationNanos = System.nanoTime(); + } + + @Override + public void run() { + boolean success = false; + + try { + if (!ctx.channel().isActive()) { + return; + } + + // Query the KvState instance + byte[] serializedKeyAndNamespace = request.getSerializedKeyAndNamespace(); + byte[] serializedResult = kvState.getSerializedValue(serializedKeyAndNamespace); + + if (serializedResult != null) { + // We found some data, success! + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestResult( + ctx.alloc(), + request.getRequestId(), + serializedResult); + + int highWatermark = ctx.channel().config().getWriteBufferHighWaterMark(); + + ChannelFuture write; + if (buf.readableBytes() <= highWatermark) { + write = ctx.writeAndFlush(buf); + } else { + write = ctx.writeAndFlush(new ChunkedByteBuf(buf, highWatermark)); + } + + write.addListener(new QueryResultWriteListener()); + + success = true; + } else { + // No data for the key/namespace. This is considered to be + // a failure. + ByteBuf unknownKey = KvStateRequestSerializer.serializeKvStateRequestFailure( + ctx.alloc(), + request.getRequestId(), + new UnknownKeyOrNamespace()); + + ctx.writeAndFlush(unknownKey); + } + } catch (Throwable t) { + try { + String stringifiedCause = ExceptionUtils.stringifyException(t); + String errMsg = "Failed to query state backend for query " + + request.getRequestId() + ". 
Caused by: " + stringifiedCause; + + ByteBuf err = KvStateRequestSerializer.serializeKvStateRequestFailure( + ctx.alloc(), request.getRequestId(), new RuntimeException(errMsg)); + + ctx.writeAndFlush(err); + } catch (IOException e) { + LOG.error("Failed to respond with the error after failed to query state backend", e); + } + } finally { + if (!success) { + stats.reportFailedRequest(); + } + } + } + + @Override + public String toString() { + return "AsyncKvStateQueryTask{" + + ", request=" + request + + ", creationNanos=" + creationNanos + + '}'; + } + + /** + * Callback after query result has been written. + * + *

Gathers stats and logs errors. + */ + private class QueryResultWriteListener implements ChannelFutureListener { + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + long durationMillis = (System.nanoTime() - creationNanos) / 1_000_000; + + if (future.isSuccess()) { + stats.reportSuccessfulRequest(durationMillis); + } else { + if (LOG.isDebugEnabled()) { + LOG.debug("Query " + request + " failed after " + durationMillis + " ms", future.cause()); + } + + stats.reportFailedRequest(); + } + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKeyOrNamespace.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKeyOrNamespace.java new file mode 100644 index 0000000000000..4e5a1deaa229c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKeyOrNamespace.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +/** + * Thrown if the KvState does not hold any state for the given key or namespace. + */ +public class UnknownKeyOrNamespace extends IllegalStateException { + + private static final long serialVersionUID = 1L; + + UnknownKeyOrNamespace() { + super("KvState does not hold any state for key/namespace."); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKvStateID.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKvStateID.java new file mode 100644 index 0000000000000..cc60035042ecd --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/UnknownKvStateID.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.util.Preconditions; + +/** + * Thrown if no KvState with the given ID cannot found by the server handler. 
+ */ +public class UnknownKvStateID extends IllegalStateException { + + private static final long serialVersionUID = 1L; + + public UnknownKvStateID(KvStateID kvStateId) { + super("No KvState registered with ID " + Preconditions.checkNotNull(kvStateId, "KvStateID") + + " at TaskManager."); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequest.java new file mode 100644 index 0000000000000..0abb653797099 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty.message; + +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.util.Preconditions; + +/** + * A {@link KvState} instance request for a specific key and namespace. + */ +public final class KvStateRequest { + + /** ID for this request. */ + private final long requestId; + + /** ID of the requested KvState instance. */ + private final KvStateID kvStateId; + + /** Serialized key and namespace to request from the KvState instance. */ + private final byte[] serializedKeyAndNamespace; + + /** + * Creates a KvState instance request. + * + * @param requestId ID for this request + * @param kvStateId ID of the requested KvState instance + * @param serializedKeyAndNamespace Serialized key and namespace to request from the KvState + * instance + */ + KvStateRequest(long requestId, KvStateID kvStateId, byte[] serializedKeyAndNamespace) { + this.requestId = requestId; + this.kvStateId = Preconditions.checkNotNull(kvStateId, "KvStateID"); + this.serializedKeyAndNamespace = Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace"); + } + + /** + * Returns the request ID. + * + * @return Request ID + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the ID of the requested KvState instance. + * + * @return ID of the requested KvState instance + */ + public KvStateID getKvStateId() { + return kvStateId; + } + + /** + * Returns the serialized key and namespace to request from the KvState + * instance. 
+ * + * @return Serialized key and namespace to request from the KvState instance + */ + public byte[] getSerializedKeyAndNamespace() { + return serializedKeyAndNamespace; + } + + @Override + public String toString() { + return "KvStateRequest{" + + "requestId=" + requestId + + ", kvStateId=" + kvStateId + + ", serializedKeyAndNamespace.length=" + serializedKeyAndNamespace.length + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestFailure.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestFailure.java new file mode 100644 index 0000000000000..06a3ce89e24ba --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestFailure.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty.message; + +/** + * A failure response to a {@link KvStateRequest}. + */ +public final class KvStateRequestFailure { + + /** ID of the request responding to. */ + private final long requestId; + + /** Failure cause. Not allowed to be a user type. */ + private final Throwable cause; + + /** + * Creates a failure response to a {@link KvStateRequest}. + * + * @param requestId ID for the request responding to + * @param cause Failure cause (not allowed to be a user type) + */ + KvStateRequestFailure(long requestId, Throwable cause) { + this.requestId = requestId; + this.cause = cause; + } + + /** + * Returns the request ID responding to. + * + * @return Request ID responding to + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the failure cause. + * + * @return Failure cause + */ + public Throwable getCause() { + return cause; + } + + @Override + public String toString() { + return "KvStateRequestFailure{" + + "requestId=" + requestId + + ", cause=" + cause + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestResult.java new file mode 100644 index 0000000000000..2bd8a36b12993 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestResult.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty.message; + +import org.apache.flink.util.Preconditions; + +/** + * A successful response to a {@link KvStateRequest} containing the serialized + * result for the requested key and namespace. + */ +public final class KvStateRequestResult { + + /** ID of the request responding to. */ + private final long requestId; + + /** + * Serialized result for the requested key and namespace. If no result was + * available for the specified key and namespace, this is null. + */ + private final byte[] serializedResult; + + /** + * Creates a successful {@link KvStateRequestResult} response. + * + * @param requestId ID of the request responding to + * @param serializedResult Serialized result or null if none + */ + KvStateRequestResult(long requestId, byte[] serializedResult) { + this.requestId = requestId; + this.serializedResult = Preconditions.checkNotNull(serializedResult, "Serialization result"); + } + + /** + * Returns the request ID responding to. + * + * @return Request ID responding to + */ + public long getRequestId() { + return requestId; + } + + /** + * Returns the serialized result or null if none available. + * + * @return Serialized result or null if none available. + */ + public byte[] getSerializedResult() { + return serializedResult; + } + + @Override + public String toString() { + return "KvStateRequestResult{" + + "requestId=" + requestId + + ", serializedResult.length=" + serializedResult.length + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java new file mode 100644 index 0000000000000..0ae60f60ca77f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.query.netty.message; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.netty.KvStateClient; +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.runtime.util.DataInputDeserializer; +import org.apache.flink.runtime.util.DataOutputSerializer; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * Serialization and deserialization of messages exchanged between + * {@link KvStateClient} and {@link KvStateServer}. + * + *

The binary messages have the following format: + * + *

+ *                     <------ Frame ------------------------->
+ *                    +----------------------------------------+
+ *                    |        HEADER (8)      | PAYLOAD (VAR) |
+ * +------------------+----------------------------------------+
+ * | FRAME LENGTH (4) | VERSION (4) | TYPE (4) | CONTENT (VAR) |
+ * +------------------+----------------------------------------+
+ * 
+ * + *

The concrete content of a message depends on the {@link KvStateRequestType}. + */ +public final class KvStateRequestSerializer { + + /** The serialization version ID. */ + private static final int VERSION = 0x79a1b710; + + /** Byte length of the header. */ + private static final int HEADER_LENGTH = 8; + + // ------------------------------------------------------------------------ + // Serialization + // ------------------------------------------------------------------------ + + /** + * Allocates a buffer and serializes the KvState request into it. + * + * @param alloc ByteBuf allocator for the buffer to + * serialize message into + * @param requestId ID for this request + * @param kvStateId ID of the requested KvState instance + * @param serializedKeyAndNamespace Serialized key and namespace to request + * from the KvState instance. + * @return Serialized KvState request message + */ + public static ByteBuf serializeKvStateRequest( + ByteBufAllocator alloc, + long requestId, + KvStateID kvStateId, + byte[] serializedKeyAndNamespace) { + + // Header + request ID + KvState ID + Serialized namespace + int frameLength = HEADER_LENGTH + 8 + (8 + 8) + (4 + serializedKeyAndNamespace.length); + ByteBuf buf = alloc.ioBuffer(frameLength + 4); // +4 for frame length + + buf.writeInt(frameLength); + + writeHeader(buf, KvStateRequestType.REQUEST); + + buf.writeLong(requestId); + buf.writeLong(kvStateId.getLowerPart()); + buf.writeLong(kvStateId.getUpperPart()); + buf.writeInt(serializedKeyAndNamespace.length); + buf.writeBytes(serializedKeyAndNamespace); + + return buf; + } + + /** + * Allocates a buffer and serializes the KvState request result into it. + * + * @param alloc ByteBuf allocator for the buffer to serialize message into + * @param requestId ID for this request + * @param serializedResult Serialized Result + * @return Serialized KvState request result message + */ + public static ByteBuf serializeKvStateRequestResult( + ByteBufAllocator alloc, + long requestId, + byte[] serializedResult) { + + Preconditions.checkNotNull(serializedResult, "Serialized result"); + + // Header + request ID + serialized result + int frameLength = HEADER_LENGTH + 8 + 4 + serializedResult.length; + + ByteBuf buf = alloc.ioBuffer(frameLength); + + buf.writeInt(frameLength); + writeHeader(buf, KvStateRequestType.REQUEST_RESULT); + buf.writeLong(requestId); + + buf.writeInt(serializedResult.length); + buf.writeBytes(serializedResult); + + return buf; + } + + /** + * Allocates a buffer and serializes the KvState request failure into it. + * + * @param alloc ByteBuf allocator for the buffer to serialize message into + * @param requestId ID of the request responding to + * @param cause Failure cause + * @return Serialized KvState request failure message + * @throws IOException Serialization failures are forwarded + */ + public static ByteBuf serializeKvStateRequestFailure( + ByteBufAllocator alloc, + long requestId, + Throwable cause) throws IOException { + + ByteBuf buf = alloc.ioBuffer(); + + // Frame length is set at the end + buf.writeInt(0); + + writeHeader(buf, KvStateRequestType.REQUEST_FAILURE); + + // Message + buf.writeLong(requestId); + + try (ByteBufOutputStream bbos = new ByteBufOutputStream(buf); + ObjectOutputStream out = new ObjectOutputStream(bbos)) { + + out.writeObject(cause); + } + + // Set frame length + int frameLength = buf.readableBytes() - 4; + buf.setInt(0, frameLength); + + return buf; + } + + /** + * Allocates a buffer and serializes the server failure into it. + * + *

The cause must not be or contain any user types as causes. + * + * @param alloc ByteBuf allocator for the buffer to serialize message into + * @param cause Failure cause + * @return Serialized server failure message + * @throws IOException Serialization failures are forwarded + */ + public static ByteBuf serializeServerFailure(ByteBufAllocator alloc, Throwable cause) throws IOException { + ByteBuf buf = alloc.ioBuffer(); + + // Frame length is set at end + buf.writeInt(0); + + writeHeader(buf, KvStateRequestType.SERVER_FAILURE); + + try (ByteBufOutputStream bbos = new ByteBufOutputStream(buf); + ObjectOutputStream out = new ObjectOutputStream(bbos)) { + + out.writeObject(cause); + } + + // Set frame length + int frameLength = buf.readableBytes() - 4; + buf.setInt(0, frameLength); + + return buf; + } + + // ------------------------------------------------------------------------ + // Deserialization + // ------------------------------------------------------------------------ + + /** + * Deserializes the header and returns the request type. + * + * @param buf Buffer to deserialize (expected to be at header position) + * @return Deserialzied request type + * @throws IllegalArgumentException If unexpected message version or message type + */ + public static KvStateRequestType deserializeHeader(ByteBuf buf) { + // Check the version + int version = buf.readInt(); + if (version != VERSION) { + throw new IllegalArgumentException("Illegal message version " + version + + ". Expected: " + VERSION + "."); + } + + // Get the message type + int msgType = buf.readInt(); + KvStateRequestType[] values = KvStateRequestType.values(); + if (msgType >= 0 && msgType <= values.length) { + return values[msgType]; + } else { + throw new IllegalArgumentException("Illegal message type with index " + msgType); + } + } + + /** + * Deserializes the KvState request message. + * + *

Important: the returned buffer is sliced from the + * incoming ByteBuf stream and retained. Therefore, it needs to be recycled + * by the consumer. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequest + */ + public static KvStateRequest deserializeKvStateRequest(ByteBuf buf) { + long requestId = buf.readLong(); + KvStateID kvStateId = new KvStateID(buf.readLong(), buf.readLong()); + + // Serialized key and namespace + int length = buf.readInt(); + + if (length < 0) { + throw new IllegalArgumentException("Negative length for serialized key and namespace. " + + "This indicates a serialization error."); + } + + // Copy the buffer in order to be able to safely recycle the ByteBuf + byte[] serializedKeyAndNamespace = new byte[length]; + if (length > 0) { + buf.readBytes(serializedKeyAndNamespace); + } + + return new KvStateRequest(requestId, kvStateId, serializedKeyAndNamespace); + } + + /** + * Deserializes the KvState request result. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequestResult + */ + public static KvStateRequestResult deserializeKvStateRequestResult(ByteBuf buf) { + long requestId = buf.readLong(); + + // Serialized KvState + int length = buf.readInt(); + + if (length < 0) { + throw new IllegalArgumentException("Negative length for serialized result. " + + "This indicates a serialization error."); + } + + byte[] serializedValue = new byte[length]; + + if (length > 0) { + buf.readBytes(serializedValue); + } + + return new KvStateRequestResult(requestId, serializedValue); + } + + /** + * Deserializes the KvState request failure. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequestFailure + */ + public static KvStateRequestFailure deserializeKvStateRequestFailure(ByteBuf buf) throws IOException, ClassNotFoundException { + long requestId = buf.readLong(); + + Throwable cause; + try (ByteBufInputStream bbis = new ByteBufInputStream(buf); + ObjectInputStream in = new ObjectInputStream(bbis)) { + + cause = (Throwable) in.readObject(); + } + + return new KvStateRequestFailure(requestId, cause); + } + + /** + * Deserializes the KvState request failure. + * + * @param buf Buffer to deserialize (expected to be positioned after header) + * @return Deserialized KvStateRequestFailure + * @throws IOException Serialization failure are forwarded + * @throws ClassNotFoundException If Exception type can not be loaded + */ + public static Throwable deserializeServerFailure(ByteBuf buf) throws IOException, ClassNotFoundException { + try (ByteBufInputStream bbis = new ByteBufInputStream(buf); + ObjectInputStream in = new ObjectInputStream(bbis)) { + + return (Throwable) in.readObject(); + } + } + + // ------------------------------------------------------------------------ + // Generic serialization utils + // ------------------------------------------------------------------------ + + /** + * Serializes the key and namespace into a {@link ByteBuffer}. + * + *

The serialized format matches the RocksDB state backend key format, i.e. + * the key and namespace don't have to be deserialized for RocksDB lookups. + * + * @param key Key to serialize + * @param keySerializer Serializer for the key + * @param namespace Namespace to serialize + * @param namespaceSerializer Serializer for the namespace + * @param Key type + * @param Namespace type + * @return Buffer holding the serialized key and namespace + * @throws IOException Serialization errors are forwarded + */ + public static byte[] serializeKeyAndNamespace( + K key, + TypeSerializer keySerializer, + N namespace, + TypeSerializer namespaceSerializer) throws IOException { + + DataOutputSerializer dos = new DataOutputSerializer(32); + + keySerializer.serialize(key, dos); + dos.writeByte(42); + namespaceSerializer.serialize(namespace, dos); + + return dos.getCopyOfBuffer(); + } + + /** + * Deserializes the key and namespace into a {@link Tuple2}. + * + * @param serializedKeyAndNamespace Serialized key and namespace + * @param keySerializer Serializer for the key + * @param namespaceSerializer Serializer for the namespace + * @param Key type + * @param Namespace + * @return Tuple2 holding deserialized key and namespace + * @throws IOException Serialization errors are forwarded + * @throws IllegalStateException If unexpected magic number between key and namespace + */ + public static Tuple2 deserializeKeyAndNamespace( + byte[] serializedKeyAndNamespace, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) throws IOException { + + DataInputDeserializer dis = new DataInputDeserializer( + serializedKeyAndNamespace, + 0, + serializedKeyAndNamespace.length); + + K key = keySerializer.deserialize(dis); + byte magicNumber = dis.readByte(); + if (magicNumber != 42) { + throw new IllegalArgumentException("Unexpected magic number " + magicNumber + + ". This indicates a mismatch in the key serializers used by the " + + "KvState instance and this access."); + } + N namespace = namespaceSerializer.deserialize(dis); + + if (dis.available() > 0) { + throw new IllegalArgumentException("Unconsumed bytes in the serialized key " + + "and namespace. This indicates a mismatch in the key/namespace " + + "serializers used by the KvState instance and this access."); + } + + return new Tuple2<>(key, namespace); + } + + /** + * Serializes the value with the given serializer. + * + * @param value Value of type T to serialize + * @param serializer Serializer for T + * @param Type of the value + * @return Serialized value or null if value null + * @throws IOException On failure during serialization + */ + public static byte[] serializeValue(T value, TypeSerializer serializer) throws IOException { + if (value != null) { + // Serialize + DataOutputSerializer dos = new DataOutputSerializer(32); + serializer.serialize(value, dos); + return dos.getCopyOfBuffer(); + } else { + return null; + } + } + + /** + * Deserializes the value with the given serializer. 
+ * + * @param serializedValue Serialized value of type T + * @param serializer Serializer for T + * @param Type of the value + * @return Deserialized value or null if the serialized value + * is null + * @throws IOException On failure during deserialization + */ + public static T deserializeValue(byte[] serializedValue, TypeSerializer serializer) throws IOException { + if (serializedValue == null) { + return null; + } else { + DataInputDeserializer deser = new DataInputDeserializer(serializedValue, 0, serializedValue.length); + return serializer.deserialize(deser); + } + } + + /** + * Serializes all values of the Iterable with the given serializer. + * + * @param values Values of type T to serialize + * @param serializer Serializer for T + * @param Type of the values + * @return Serialized values or null if values null or empty + * @throws IOException On failure during serialization + */ + public static byte[] serializeList(Iterable values, TypeSerializer serializer) throws IOException { + if (values != null) { + Iterator it = values.iterator(); + + if (it.hasNext()) { + // Serialize + DataOutputSerializer dos = new DataOutputSerializer(32); + + while (it.hasNext()) { + serializer.serialize(it.next(), dos); + + // This byte added here in order to have the binary format + // prescribed by RocksDB. + dos.write(0); + } + + return dos.getCopyOfBuffer(); + } else { + return null; + } + } else { + return null; + } + } + + /** + * Deserializes all values with the given serializer. + * + * @param serializedValue Serialized value of type List + * @param serializer Serializer for T + * @param Type of the value + * @return Deserialized list or null if the serialized value + * is null + * @throws IOException On failure during deserialization + */ + public static List deserializeList(byte[] serializedValue, TypeSerializer serializer) throws IOException { + if (serializedValue != null) { + DataInputDeserializer in = new DataInputDeserializer(serializedValue, 0, serializedValue.length); + + List result = new ArrayList<>(); + while (in.available() > 0) { + result.add(serializer.deserialize(in)); + + // The expected binary format has a single byte separator. We + // want a consistent binary format in order to not need any + // special casing during deserialization. A "cleaner" format + // would skip this extra byte, but would require a memory copy + // for RocksDB, which stores the data serialized in this way + // for lists. + if (in.available() > 0) { + in.readByte(); + } + } + + return result; + } else { + return null; + } + } + + // ------------------------------------------------------------------------ + + /** + * Helper for writing the header. + * + * @param buf Buffer to serialize header into + * @param requestType Result type to serialize + */ + private static void writeHeader(ByteBuf buf, KvStateRequestType requestType) { + buf.writeInt(VERSION); + buf.writeInt(requestType.ordinal()); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java new file mode 100644 index 0000000000000..de7270a1c135c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty.message; + +import org.apache.flink.runtime.query.netty.KvStateServer; + +/** + * Expected message types when communicating with the {@link KvStateServer}. + */ +public enum KvStateRequestType { + + /** Request a KvState instance. */ + REQUEST, + + /** Successful response to a KvStateRequest. */ + REQUEST_RESULT, + + /** Failure response to a KvStateRequest. */ + REQUEST_FAILURE, + + /** Generic server failure. */ + SERVER_FAILURE + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java new file mode 100644 index 0000000000000..7e8de40858fd1 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * This package contains all Netty-based client/server classes used to query + * KvState instances. + * + *

+ * <h2>Server and Client</h2>
+ *

+ * <p>Both server and client expect received binary messages to contain a frame
+ * length field. Netty's {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
+ * is used to fully receive the frame before giving it to the respective client
+ * or server handler.
+ *

+ * <p>Connection establishment and release is handled by the client. The server
+ * only closes a connection if a fatal failure happens that cannot be resolved
+ * otherwise.
+ *

+ * <p>There is a single server per task manager and a single client can be
+ * shared by multiple threads.
+ *
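A minimal sketch (not part of this patch) of the sharing model described here: one KvStateClient instance used from several threads. The port, the KvStateID and the empty key/namespace payload are placeholders.

import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats;
import org.apache.flink.runtime.query.netty.KvStateClient;

import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;

import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

public class SharedClientSketch {

	public static void main(String[] args) throws Exception {
		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

		// A single client instance with one event loop thread
		final KvStateClient client = new KvStateClient(1, stats);

		// Placeholder address; a real caller gets this from the KvStateServer
		final KvStateServerAddress address =
				new KvStateServerAddress(InetAddress.getLocalHost(), 9069);

		Runnable query = new Runnable() {
			@Override
			public void run() {
				try {
					// Placeholder KvStateID and serialized key/namespace
					Future<byte[]> result =
							client.getKvState(address, new KvStateID(), new byte[0]);
					Await.result(result, new FiniteDuration(10, TimeUnit.SECONDS));
				} catch (Exception ignored) {
					// A real caller would handle failures (e.g. unknown KvState) here
				}
			}
		};

		Thread t1 = new Thread(query);
		Thread t2 = new Thread(query);
		t1.start();
		t2.start();
		t1.join();
		t2.join();

		client.shutDown();
	}
}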

+ * <p>See also:
+ * <ul>
+ *   <li>{@link org.apache.flink.runtime.query.netty.KvStateServer}</li>
+ *   <li>{@link org.apache.flink.runtime.query.netty.KvStateServerHandler}</li>
+ *   <li>{@link org.apache.flink.runtime.query.netty.KvStateClient}</li>
+ *   <li>{@link org.apache.flink.runtime.query.netty.KvStateClientHandler}</li>
+ * </ul>
+ *
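To make the server side concrete, here is a sketch under the same assumptions as the tests later in this patch: a MemValueState is flagged as queryable, registered with a KvStateRegistry and served by a KvStateServer. The MemValueState type parameters and the thread counts are assumptions.

import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats;
import org.apache.flink.runtime.query.netty.KvStateServer;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.memory.MemValueState;

import java.net.InetAddress;

public class QueryableStateServerSketch {

	public static void main(String[] args) throws Exception {
		// State descriptor published for queries under the name "any"
		ValueStateDescriptor<Integer> desc =
				new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
		desc.setQueryable("any");

		// Heap-backed value state (type parameters assumed)
		MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
				IntSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc);

		kvState.setCurrentKey(1010);
		kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
		kvState.update(201);

		// Publish the instance and start a server for it
		KvStateRegistry registry = new KvStateRegistry();
		KvStateID id = registry.registerKvState(
				new JobID(), new JobVertexID(), 0, "any", kvState);

		KvStateServer server = new KvStateServer(
				InetAddress.getLocalHost(), 0, 1, 1, registry, new AtomicKvStateRequestStats());
		server.start();

		System.out.println("Serving KvState " + id + " at " + server.getAddress());

		server.shutDown();
	}
}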

+ * <h2>Serialization</h2>
+ *

+ * <p>The exchanged binary messages have the following format:
+ *
+ * <pre>

+ *                     <------ Frame ------------------------->
+ *                    +----------------------------------------+
+ *                    |        HEADER (8)      | PAYLOAD (VAR) |
+ * +------------------+----------------------------------------+
+ * | FRAME LENGTH (4) | VERSION (4) | TYPE (4) | CONTENT (VAR) |
+ * +------------------+----------------------------------------+
+ * </pre>
+ *
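As a rough illustration of the frame layout above, a sketch (not part of the patch) that serializes a REQUEST with the helpers added in this change and reads the frame length, header and content back; the request ID, KvStateID and payload are arbitrary.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.netty.message.KvStateRequest;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.runtime.query.netty.message.KvStateRequestType;

public class FrameFormatSketch {

	public static void main(String[] args) {
		byte[] serializedKeyAndNamespace = new byte[] { 1, 2, 3 };

		ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest(
				UnpooledByteBufAllocator.DEFAULT, 42L, new KvStateID(), serializedKeyAndNamespace);

		int frameLength = buf.readInt();                              // FRAME LENGTH (4)
		KvStateRequestType type =
				KvStateRequestSerializer.deserializeHeader(buf);          // VERSION (4) + TYPE (4)
		KvStateRequest request =
				KvStateRequestSerializer.deserializeKvStateRequest(buf);  // CONTENT (VAR)

		System.out.println(frameLength + " bytes, " + type + ", request " + request.getRequestId());

		buf.release();
	}
}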

+ * <p>For frame decoding, both server and client use Netty's
+ * {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}. Message serialization
+ * is done via static helpers in
+ * {@link org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer}.
+ * The serialization helpers return {@link io.netty.buffer.ByteBuf} instances,
+ * which are ready to be sent to the client or server respectively, as they
+ * already contain the frame length.
+ *
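For reference, a sketch of the decoder setup described here, using the same LengthFieldBasedFrameDecoder settings as the tests in this patch; the initializer class name is made up and the actual KvState handlers are omitted.

import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

public class FrameDecoderInitializerSketch extends ChannelInitializer<SocketChannel> {

	@Override
	protected void initChannel(SocketChannel ch) throws Exception {
		// maxFrameLength, lengthFieldOffset, lengthFieldLength,
		// lengthAdjustment, initialBytesToStrip (drop the 4 byte frame length)
		ch.pipeline().addLast(
				new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));

		// The KvStateClientHandler or KvStateServerHandler would be added next.
	}
}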

+ * <p>See also:
+ * <ul>
+ *   <li>{@link org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer}</li>
+ * </ul>
+ *
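A short sketch (not part of the patch) of the static helpers listed above, assuming the key and namespace types used by the tests (int key, VoidNamespace).

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;

public class SerializationHelperSketch {

	public static void main(String[] args) throws Exception {
		// Key bytes, magic byte (42), namespace bytes -- matches the RocksDB key layout
		byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
				1010, IntSerializer.INSTANCE, VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);

		// Round-trip a value the way a query response is encoded and decoded
		byte[] serializedValue = KvStateRequestSerializer.serializeValue(201, IntSerializer.INSTANCE);
		int value = KvStateRequestSerializer.deserializeValue(serializedValue, IntSerializer.INSTANCE);

		System.out.println(serializedKeyAndNamespace.length + " key/namespace bytes, value = " + value);
	}
}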

+ * <h2>Statistics</h2>
+ *

+ * <p>Both server and client keep track of request statistics via
+ * {@link org.apache.flink.runtime.query.netty.KvStateRequestStats}.
+ *
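A brief sketch of how the statistics object is consumed; the getters mirror the ones asserted on in the tests below, the wiring comment is illustrative.

import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats;

public class RequestStatsSketch {

	public static void main(String[] args) {
		// Implements KvStateRequestStats; handed to KvStateClient / KvStateServer
		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

		System.out.println("requests=" + stats.getNumRequests()
				+ ", successful=" + stats.getNumSuccessful()
				+ ", failed=" + stats.getNumFailed()
				+ ", connections=" + stats.getNumConnections());
	}
}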

+ * <p>See also:
+ * <ul>
+ *   <li>{@link org.apache.flink.runtime.query.netty.KvStateRequestStats}</li>
+ * </ul>
+ */ +package org.apache.flink.runtime.query.netty; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java index bdccdd121c01b..9822a834f6283 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java @@ -95,6 +95,14 @@ public void releaseArrays() { // ---------------------------------------------------------------------------------------- // Data Input // ---------------------------------------------------------------------------------------- + + public int available() { + if (position < end) { + return end - position - 1; + } else { + return 0; + } + } @Override public boolean readBoolean() throws IOException { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java new file mode 100644 index 0000000000000..31a96203269b4 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.junit.Test; + +import java.nio.channels.ClosedChannelException; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class KvStateClientHandlerTest { + + /** + * Tests that on reads the expected callback methods are called and read + * buffers are recycled. 
+ */ + @Test + public void testReadCallbacksAndBufferRecycling() throws Exception { + KvStateClientHandlerCallback callback = mock(KvStateClientHandlerCallback.class); + + EmbeddedChannel channel = new EmbeddedChannel(new KvStateClientHandler(callback)); + + // + // Request success + // + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestResult( + channel.alloc(), + 1222112277, + new byte[0]); + buf.skipBytes(4); // skip frame length + + // Verify callback + channel.writeInbound(buf); + verify(callback, times(1)).onRequestResult(eq(1222112277L), any(byte[].class)); + assertEquals("Buffer not recycled", 0, buf.refCnt()); + + // + // Request failure + // + buf = KvStateRequestSerializer.serializeKvStateRequestFailure( + channel.alloc(), + 1222112278, + new RuntimeException("Expected test Exception")); + buf.skipBytes(4); // skip frame length + + // Verify callback + channel.writeInbound(buf); + verify(callback, times(1)).onRequestFailure(eq(1222112278L), any(RuntimeException.class)); + assertEquals("Buffer not recycled", 0, buf.refCnt()); + + // + // Server failure + // + buf = KvStateRequestSerializer.serializeServerFailure( + channel.alloc(), + new RuntimeException("Expected test Exception")); + buf.skipBytes(4); // skip frame length + + // Verify callback + channel.writeInbound(buf); + verify(callback, times(1)).onFailure(any(RuntimeException.class)); + + // + // Unexpected messages + // + buf = channel.alloc().buffer(4).writeInt(1223823); + + // Verify callback + channel.writeInbound(buf); + verify(callback, times(2)).onFailure(any(IllegalStateException.class)); + assertEquals("Buffer not recycled", 0, buf.refCnt()); + + // + // Exception caught + // + channel.pipeline().fireExceptionCaught(new RuntimeException("Expected test Exception")); + verify(callback, times(3)).onFailure(any(RuntimeException.class)); + + // + // Channel inactive + // + channel.pipeline().fireChannelInactive(); + verify(callback, times(4)).onFailure(any(ClosedChannelException.class)); + } + +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java new file mode 100644 index 0000000000000..72d9f618465d6 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java @@ -0,0 +1,718 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.netty.message.KvStateRequest; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.state.memory.MemValueState; +import org.apache.flink.util.NetUtils; +import org.junit.AfterClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; + +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class KvStateClientTest { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateClientTest.class); + + // Thread pool for client bootstrap (shared between tests) + private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup(); + + private final static FiniteDuration TEST_TIMEOUT = new FiniteDuration(100, TimeUnit.SECONDS); + + @AfterClass + public static void tearDown() throws Exception { + if (NIO_GROUP != null) { + NIO_GROUP.shutdownGracefully(); + } + } + + /** + * Tests simple queries, of which half succeed and half fail. 
+ */ + @Test + public void testSimpleRequests() throws Exception { + Deadline deadline = TEST_TIMEOUT.fromNow(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateClient client = null; + Channel serverChannel = null; + + try { + client = new KvStateClient(1, stats); + + // Random result + final byte[] expected = new byte[1024]; + ThreadLocalRandom.current().nextBytes(expected); + + final LinkedBlockingQueue received = new LinkedBlockingQueue<>(); + final AtomicReference channel = new AtomicReference<>(); + + serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + channel.set(ctx.channel()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + received.add((ByteBuf) msg); + } + }); + + KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel); + + List> futures = new ArrayList<>(); + + int numQueries = 1024; + + for (int i = 0; i < numQueries; i++) { + futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0])); + } + + // Respond to messages + Exception testException = new RuntimeException("Expected test Exception"); + + for (int i = 0; i < numQueries; i++) { + ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); + assertNotNull("Receive timed out", buf); + + Channel ch = channel.get(); + assertNotNull("Channel not active", ch); + + assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf); + + buf.release(); + + if (i % 2 == 0) { + ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult( + serverChannel.alloc(), + request.getRequestId(), + expected); + + ch.writeAndFlush(response); + } else { + ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestFailure( + serverChannel.alloc(), + request.getRequestId(), + testException); + + ch.writeAndFlush(response); + } + } + + for (int i = 0; i < numQueries; i++) { + if (i % 2 == 0) { + byte[] serializedResult = Await.result(futures.get(i), deadline.timeLeft()); + assertArrayEquals(expected, serializedResult); + } else { + try { + Await.result(futures.get(i), deadline.timeLeft()); + fail("Did not throw expected Exception"); + } catch (RuntimeException ignored) { + // Expected + } + } + } + + assertEquals(numQueries, stats.getNumRequests()); + int expectedRequests = numQueries / 2; + + // Counts can take some time to propagate + while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != expectedRequests || + stats.getNumFailed() != expectedRequests)) { + Thread.sleep(100); + } + + assertEquals(expectedRequests, stats.getNumSuccessful()); + assertEquals(expectedRequests, stats.getNumFailed()); + } finally { + if (client != null) { + client.shutDown(); + } + + if (serverChannel != null) { + serverChannel.close(); + } + + assertEquals("Channel leak", 0, stats.getNumConnections()); + } + } + + /** + * Tests that a request to an unavailable host is failed with ConnectException. 
+ */ + @Test + public void testRequestUnavailableHost() throws Exception { + Deadline deadline = TEST_TIMEOUT.fromNow(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + KvStateClient client = null; + + try { + client = new KvStateClient(1, stats); + + int availablePort = NetUtils.getAvailablePort(); + + KvStateServerAddress serverAddress = new KvStateServerAddress( + InetAddress.getLocalHost(), + availablePort); + + Future future = client.getKvState(serverAddress, new KvStateID(), new byte[0]); + + try { + Await.result(future, deadline.timeLeft()); + fail("Did not throw expected ConnectException"); + } catch (ConnectException ignored) { + // Expected + } + } finally { + if (client != null) { + client.shutDown(); + } + + assertEquals("Channel leak", 0, stats.getNumConnections()); + } + } + + /** + * Multiple threads concurrently fire queries. + */ + @Test + public void testConcurrentQueries() throws Exception { + Deadline deadline = TEST_TIMEOUT.fromNow(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + ExecutorService executor = null; + KvStateClient client = null; + Channel serverChannel = null; + + final byte[] serializedResult = new byte[1024]; + ThreadLocalRandom.current().nextBytes(serializedResult); + + try { + int numQueryTasks = 4; + final int numQueriesPerTask = 1024; + + executor = Executors.newFixedThreadPool(numQueryTasks); + + client = new KvStateClient(1, stats); + + serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ByteBuf buf = (ByteBuf) msg; + assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf); + + buf.release(); + + ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult( + ctx.alloc(), + request.getRequestId(), + serializedResult); + + ctx.channel().writeAndFlush(response); + } + }); + + final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel); + + final KvStateClient finalClient = client; + Callable>> queryTask = new Callable>>() { + @Override + public List> call() throws Exception { + List> results = new ArrayList<>(numQueriesPerTask); + + for (int i = 0; i < numQueriesPerTask; i++) { + results.add(finalClient.getKvState( + serverAddress, + new KvStateID(), + new byte[0])); + } + + return results; + } + }; + + // Submit query tasks + List>>> futures = new ArrayList<>(); + for (int i = 0; i < numQueryTasks; i++) { + futures.add(executor.submit(queryTask)); + } + + // Verify results + for (java.util.concurrent.Future>> future : futures) { + List> results = future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); + for (Future result : results) { + byte[] actual = Await.result(result, deadline.timeLeft()); + assertArrayEquals(serializedResult, actual); + } + } + + int totalQueries = numQueryTasks * numQueriesPerTask; + + // Counts can take some time to propagate + while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != totalQueries || + stats.getNumFailed() != totalQueries)) { + Thread.sleep(100); + } + + assertEquals(totalQueries, stats.getNumRequests()); + assertEquals(totalQueries, stats.getNumSuccessful()); + } finally { + if (executor != null) { + executor.shutdown(); + } + + if (serverChannel != null) { + serverChannel.close(); + } + + if (client != null) { + client.shutDown(); + } + + assertEquals("Channel leak", 0, 
stats.getNumConnections()); + } + } + + /** + * Tests that a server failure closes the connection and removes it from + * the established connections. + */ + @Test + public void testFailureClosesChannel() throws Exception { + Deadline deadline = TEST_TIMEOUT.fromNow(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateClient client = null; + Channel serverChannel = null; + + try { + client = new KvStateClient(1, stats); + + final LinkedBlockingQueue received = new LinkedBlockingQueue<>(); + final AtomicReference channel = new AtomicReference<>(); + + serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + channel.set(ctx.channel()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + received.add((ByteBuf) msg); + } + }); + + KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel); + + // Requests + List> futures = new ArrayList<>(); + futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0])); + futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0])); + + ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); + assertNotNull("Receive timed out", buf); + buf.release(); + + buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); + assertNotNull("Receive timed out", buf); + buf.release(); + + assertEquals(1, stats.getNumConnections()); + + Channel ch = channel.get(); + assertNotNull("Channel not active", ch); + + // Respond with failure + ch.writeAndFlush(KvStateRequestSerializer.serializeServerFailure( + serverChannel.alloc(), + new RuntimeException("Expected test server failure"))); + + try { + Await.result(futures.remove(0), deadline.timeLeft()); + fail("Did not throw expected server failure"); + } catch (RuntimeException ignored) { + // Expected + } + + try { + Await.result(futures.remove(0), deadline.timeLeft()); + fail("Did not throw expected server failure"); + } catch (RuntimeException ignored) { + // Expected + } + + assertEquals(0, stats.getNumConnections()); + + // Counts can take some time to propagate + while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 || + stats.getNumFailed() != 2)) { + Thread.sleep(100); + } + + assertEquals(2, stats.getNumRequests()); + assertEquals(0, stats.getNumSuccessful()); + assertEquals(2, stats.getNumFailed()); + } finally { + if (client != null) { + client.shutDown(); + } + + if (serverChannel != null) { + serverChannel.close(); + } + + assertEquals("Channel leak", 0, stats.getNumConnections()); + } + } + + /** + * Tests that a server channel close, closes the connection and removes it + * from the established connections. 
+ */ + @Test + public void testServerClosesChannel() throws Exception { + Deadline deadline = TEST_TIMEOUT.fromNow(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateClient client = null; + Channel serverChannel = null; + + try { + client = new KvStateClient(1, stats); + + final AtomicBoolean received = new AtomicBoolean(); + final AtomicReference channel = new AtomicReference<>(); + + serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + channel.set(ctx.channel()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + received.set(true); + } + }); + + KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel); + + // Requests + Future future = client.getKvState(serverAddress, new KvStateID(), new byte[0]); + + while (!received.get() && deadline.hasTimeLeft()) { + Thread.sleep(50); + } + assertTrue("Receive timed out", received.get()); + + assertEquals(1, stats.getNumConnections()); + + channel.get().close().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); + + try { + Await.result(future, deadline.timeLeft()); + fail("Did not throw expected server failure"); + } catch (ClosedChannelException ignored) { + // Expected + } + + assertEquals(0, stats.getNumConnections()); + + // Counts can take some time to propagate + while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 || + stats.getNumFailed() != 1)) { + Thread.sleep(100); + } + + assertEquals(1, stats.getNumRequests()); + assertEquals(0, stats.getNumSuccessful()); + assertEquals(1, stats.getNumFailed()); + } finally { + if (client != null) { + client.shutDown(); + } + + if (serverChannel != null) { + serverChannel.close(); + } + + assertEquals("Channel leak", 0, stats.getNumConnections()); + } + } + + /** + * Tests multiple clients querying multiple servers until 100k queries have + * been processed. At this point, the client is shut down and its verified + * that all ongoing requests are failed. 
+ */ + @Test + public void testClientServerIntegration() throws Exception { + // Config + final int numServers = 2; + final int numServerEventLoopThreads = 2; + final int numServerQueryThreads = 2; + + final int numClientEventLoopThreads = 4; + final int numClientsTasks = 8; + + final int batchSize = 16; + + final FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS); + + AtomicKvStateRequestStats clientStats = new AtomicKvStateRequestStats(); + + KvStateClient client = null; + ExecutorService clientTaskExecutor = null; + final KvStateServer[] server = new KvStateServer[numServers]; + + try { + client = new KvStateClient(numClientEventLoopThreads, clientStats); + clientTaskExecutor = Executors.newFixedThreadPool(numClientsTasks); + + // Create state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + // Create servers + KvStateRegistry[] registry = new KvStateRegistry[numServers]; + AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers]; + final KvStateID[] ids = new KvStateID[numServers]; + + for (int i = 0; i < numServers; i++) { + registry[i] = new KvStateRegistry(); + serverStats[i] = new AtomicKvStateRequestStats(); + server[i] = new KvStateServer( + InetAddress.getLocalHost(), + 0, + numServerEventLoopThreads, + numServerQueryThreads, + registry[i], + serverStats[i]); + + server[i].start(); + + // Value per server + kvState.setCurrentKey(1010 + i); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.update(201 + i); + + // Register KvState (one state instance for all server) + ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), 0, "any", kvState); + } + + final KvStateClient finalClient = client; + Callable queryTask = new Callable() { + @Override + public Void call() throws Exception { + while (true) { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + + // Random server permutation + List random = new ArrayList<>(); + for (int j = 0; j < batchSize; j++) { + random.add(j); + } + Collections.shuffle(random); + + // Dispatch queries + List> futures = new ArrayList<>(batchSize); + + for (int j = 0; j < batchSize; j++) { + int targetServer = random.get(j) % numServers; + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + 1010 + targetServer, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + futures.add(finalClient.getKvState( + server[targetServer].getAddress(), + ids[targetServer], + serializedKeyAndNamespace)); + } + + // Verify results + for (int j = 0; j < batchSize; j++) { + int targetServer = random.get(j) % numServers; + + Future future = futures.get(j); + byte[] buf = Await.result(future, timeout); + int value = KvStateRequestSerializer.deserializeValue(buf, IntSerializer.INSTANCE); + assertEquals(201 + targetServer, value); + } + } + } + }; + + // Submit tasks + List> taskFutures = new ArrayList<>(); + for (int i = 0; i < numClientsTasks; i++) { + taskFutures.add(clientTaskExecutor.submit(queryTask)); + } + + long numRequests; + while ((numRequests = clientStats.getNumRequests()) < 100_000) { + Thread.sleep(100); + LOG.info("Number of requests {}/100_000", numRequests); + } + + // Shut down + client.shutDown(); + + for (java.util.concurrent.Future future : taskFutures) { + try { + future.get(); + fail("Did not throw 
expected Exception after shut down"); + } catch (ExecutionException t) { + if (t.getCause() instanceof ClosedChannelException || + t.getCause() instanceof IllegalStateException) { + // Expected + } else { + t.printStackTrace(); + fail("Failed with unexpected Exception type: " + t.getClass().getName()); + } + } + } + + assertEquals("Connection leak (client)", 0, clientStats.getNumConnections()); + for (int i = 0; i < numServers; i++) { + boolean success = false; + int numRetries = 0; + while (!success) { + try { + assertEquals("Connection leak (server)", 0, serverStats[i].getNumConnections()); + success = true; + } catch (Throwable t) { + if (numRetries < 10) { + LOG.info("Retrying connection leak check (server)"); + Thread.sleep((numRetries + 1) * 50); + numRetries++; + } else { + throw t; + } + } + } + } + } finally { + if (client != null) { + client.shutDown(); + } + + for (int i = 0; i < numServers; i++) { + if (server[i] != null) { + server[i].shutDown(); + } + } + + if (clientTaskExecutor != null) { + clientTaskExecutor.shutdown(); + } + } + } + + // ------------------------------------------------------------------------ + + private Channel createServerChannel(final ChannelHandler... handlers) throws UnknownHostException, InterruptedException { + ServerBootstrap bootstrap = new ServerBootstrap() + // Bind address and port + .localAddress(InetAddress.getLocalHost(), 0) + // NIO server channels + .group(NIO_GROUP) + .channel(NioServerSocketChannel.class) + // See initializer for pipeline details + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast(handlers); + } + }); + + return bootstrap.bind().sync().channel(); + } + + private KvStateServerAddress getKvStateServerAddress(Channel serverChannel) { + InetSocketAddress localAddress = (InetSocketAddress) serverChannel.localAddress(); + + return new KvStateServerAddress(localAddress.getAddress(), localAddress.getPort()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java new file mode 100644 index 0000000000000..6ad7ece9815ca --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure; +import org.apache.flink.runtime.query.netty.message.KvStateRequestResult; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.state.memory.MemValueState; +import org.junit.AfterClass; +import org.junit.Test; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KvStateServerHandlerTest { + + /** Shared Thread pool for query execution */ + private final static ExecutorService TEST_THREAD_POOL = Executors.newSingleThreadExecutor(); + + private final static int READ_TIMEOUT_MILLIS = 10000; + + @AfterClass + public static void tearDown() throws Exception { + if (TEST_THREAD_POOL != null) { + TEST_THREAD_POOL.shutdown(); + } + } + + /** + * Tests a simple successful query via an EmbeddedChannel. 
+ */ + @Test + public void testSimpleQuery() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Register state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + // Update the KvState and request it + int expectedValue = 712828289; + + int key = 99812822; + kvState.setCurrentKey(key); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + + kvState.update(expectedValue); + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + long requestId = Integer.MAX_VALUE + 182828L; + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + kvStateId, + serializedKeyAndNamespace); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestResult response = KvStateRequestSerializer.deserializeKvStateRequestResult(buf); + + assertEquals(requestId, response.getRequestId()); + + int actualValue = KvStateRequestSerializer.deserializeValue(response.getSerializedResult(), IntSerializer.INSTANCE); + assertEquals(expectedValue, actualValue); + + assertEquals(1, stats.getNumRequests()); + assertEquals(1, stats.getNumSuccessful()); + } + + /** + * Tests the failure response with {@link UnknownKvStateID} as cause on + * queries for unregistered KvStateIDs. + */ + @Test + public void testQueryUnknownKvStateID() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + long requestId = Integer.MAX_VALUE + 182828L; + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + new KvStateID(), + new byte[0]); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + + assertEquals(requestId, response.getRequestId()); + + assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKvStateID); + + assertEquals(1, stats.getNumRequests()); + assertEquals(1, stats.getNumFailed()); + } + + /** + * Tests the failure response with {@link UnknownKeyOrNamespace} as cause + * on queries for non-existing keys. 
+ */ + @Test + public void testQueryUnknownKey() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Register state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + 1238283, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + long requestId = Integer.MAX_VALUE + 22982L; + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + kvStateId, + serializedKeyAndNamespace); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + + assertEquals(requestId, response.getRequestId()); + + assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKeyOrNamespace); + + assertEquals(1, stats.getNumRequests()); + assertEquals(1, stats.getNumFailed()); + } + + /** + * Tests the failure response on a failure on the {@link KvState#getSerializedValue(byte[])} + * call. + */ + @Test + public void testFailureOnGetSerializedValue() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Failing KvState + KvState kvState = mock(KvState.class); + when(kvState.getSerializedValue(any(byte[].class))) + .thenThrow(new RuntimeException("Expected test Exception")); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + 282872, + kvStateId, + new byte[0]); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + + assertTrue(response.getCause().getMessage().contains("Expected test Exception")); + + assertEquals(1, stats.getNumRequests()); + assertEquals(1, stats.getNumFailed()); + } + + /** + * Tests that the channel is closed if an Exception reaches the channel + * handler. 
+ */ + @Test + public void testCloseChannelOnExceptionCaught() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.pipeline().fireExceptionCaught(new RuntimeException("Expected test Exception")); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + Throwable response = KvStateRequestSerializer.deserializeServerFailure(buf); + + assertTrue(response.getMessage().contains("Expected test Exception")); + + channel.closeFuture().await(READ_TIMEOUT_MILLIS); + assertFalse(channel.isActive()); + } + + /** + * Tests the failure response on a rejected execution, because the query + * executor has been closed. + */ + @Test + public void testQueryExecutorShutDown() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + ExecutorService closedExecutor = Executors.newSingleThreadExecutor(); + closedExecutor.shutdown(); + assertTrue(closedExecutor.isShutdown()); + + KvStateServerHandler handler = new KvStateServerHandler(registry, closedExecutor, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Register state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + 282872, + kvStateId, + new byte[0]); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + + assertTrue(response.getCause().getMessage().contains("RejectedExecutionException")); + + assertEquals(1, stats.getNumRequests()); + assertEquals(1, stats.getNumFailed()); + } + + /** + * Tests response on unexpected messages. 
+ */ + @Test + public void testUnexpectedMessage() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Write the request and wait for the response + ByteBuf unexpectedMessage = Unpooled.buffer(8); + unexpectedMessage.writeInt(4); + unexpectedMessage.writeInt(123238213); + + channel.writeInbound(unexpectedMessage); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + Throwable response = KvStateRequestSerializer.deserializeServerFailure(buf); + + assertEquals(0, stats.getNumRequests()); + assertEquals(0, stats.getNumFailed()); + + unexpectedMessage = KvStateRequestSerializer.serializeKvStateRequestResult( + channel.alloc(), + 192, + new byte[0]); + + channel.writeInbound(unexpectedMessage); + + buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + response = KvStateRequestSerializer.deserializeServerFailure(buf); + + assertTrue("Unexpected failure cause " + response.getClass().getName(), response instanceof IllegalArgumentException); + + assertEquals(0, stats.getNumRequests()); + assertEquals(0, stats.getNumFailed()); + } + + /** + * Tests that incoming buffer instances are recycled. + */ + @Test + public void testIncomingBufferIsRecycled() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + 282872, + new KvStateID(), + new byte[0]); + + assertEquals(1, request.refCnt()); + + // Write regular request + channel.writeInbound(request); + assertEquals("Buffer not recycled", 0, request.refCnt()); + + // Write unexpected msg + ByteBuf unexpected = channel.alloc().buffer(8); + unexpected.writeInt(4); + unexpected.writeInt(4); + + assertEquals(1, unexpected.refCnt()); + + channel.writeInbound(unexpected); + assertEquals("Buffer not recycled", 0, unexpected.refCnt()); + } + + /** + * Tests the failure response if the serializers don't match. 
+ */ + @Test + public void testSerializerMismatch() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Register state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + int key = 99812822; + + // Update the KvState + kvState.setCurrentKey(key); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.update(712828289); + + byte[] wrongKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + "wrong-key-type", + StringSerializer.INSTANCE, + "wrong-namespace-type", + StringSerializer.INSTANCE); + + byte[] wrongNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + "wrong-namespace-type", + StringSerializer.INSTANCE); + + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + 182828, + kvStateId, + wrongKeyAndNamespace); + + // Write the request and wait for the response + channel.writeInbound(request); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + assertEquals(182828, response.getRequestId()); + assertTrue(response.getCause().getMessage().contains("IllegalArgumentException")); + + // Repeat with wrong namespace only + request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + 182829, + kvStateId, + wrongNamespace); + + // Write the request and wait for the response + channel.writeInbound(request); + + buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + assertEquals(182829, response.getRequestId()); + assertTrue(response.getCause().getMessage().contains("IllegalArgumentException")); + + assertEquals(2, stats.getNumRequests()); + assertEquals(2, stats.getNumFailed()); + } + + /** + * Tests that large responses are chunked. 
+ */ + @Test + public void testChunkedResponse() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + KvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Register state + ValueStateDescriptor desc = new ValueStateDescriptor<>("any", BytePrimitiveArraySerializer.INSTANCE, null); + desc.setQueryable("any"); + + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + // Update KvState + byte[] bytes = new byte[2 * channel.config().getWriteBufferHighWaterMark()]; + + byte current = 0; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = current++; + } + + int key = 99812822; + kvState.setCurrentKey(key); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.update(bytes); + + // Request + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + long requestId = Integer.MAX_VALUE + 182828L; + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + kvStateId, + serializedKeyAndNamespace); + + // Write the request and wait for the response + channel.writeInbound(request); + + Object msg = readInboundBlocking(channel); + assertTrue("Not ChunkedByteBuf", msg instanceof ChunkedByteBuf); + } + + // ------------------------------------------------------------------------ + + /** + * Queries the embedded channel for data. + */ + private Object readInboundBlocking(EmbeddedChannel channel) throws InterruptedException, TimeoutException { + final int sleepMillis = 50; + + int sleptMillis = 0; + + Object msg = null; + while (sleptMillis < READ_TIMEOUT_MILLIS && + (msg = channel.readOutbound()) == null) { + + Thread.sleep(sleepMillis); + sleptMillis += sleepMillis; + } + + if (msg == null) { + throw new TimeoutException(); + } else { + return msg; + } + } + + /** + * Frame length decoder (expected by the serialized messages). + */ + private ChannelHandler getFrameDecoder() { + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java new file mode 100644 index 0000000000000..d653f73ee3a22 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.query.KvStateID; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateServerAddress; +import org.apache.flink.runtime.query.netty.message.KvStateRequestResult; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.query.netty.message.KvStateRequestType; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.state.memory.MemValueState; +import org.junit.AfterClass; +import org.junit.Test; + +import java.net.InetAddress; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; + +public class KvStateServerTest { + + // Thread pool for client bootstrap (shared between tests) + private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup(); + + private final static int TIMEOUT_MILLIS = 10000; + + @AfterClass + public static void tearDown() throws Exception { + if (NIO_GROUP != null) { + NIO_GROUP.shutdownGracefully(); + } + } + + /** + * Tests a simple successful query via a SocketChannel. 
+ */ + @Test + public void testSimpleRequest() throws Exception { + KvStateServer server = null; + Bootstrap bootstrap = null; + + try { + KvStateRegistry registry = new KvStateRegistry(); + KvStateRequestStats stats = new AtomicKvStateRequestStats(); + + server = new KvStateServer(InetAddress.getLocalHost(), 0, 1, 1, registry, stats); + server.start(); + + KvStateServerAddress serverAddress = server.getAddress(); + + // Register state + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null)); + + KvStateID kvStateId = registry.registerKvState( + new JobID(), + new JobVertexID(), + 0, + "vanilla", + kvState); + + // Update KvState + int expectedValue = 712828289; + + int key = 99812822; + kvState.setCurrentKey(key); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.update(expectedValue); + + // Request + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + // Connect to the server + final BlockingQueue responses = new LinkedBlockingQueue<>(); + bootstrap = createBootstrap( + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4), + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + responses.add((ByteBuf) msg); + } + }); + + Channel channel = bootstrap + .connect(serverAddress.getHost(), serverAddress.getPort()) + .sync().channel(); + + long requestId = Integer.MAX_VALUE + 182828L; + ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest( + channel.alloc(), + requestId, + kvStateId, + serializedKeyAndNamespace); + + channel.writeAndFlush(request); + + ByteBuf buf = responses.poll(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + + assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestResult response = KvStateRequestSerializer.deserializeKvStateRequestResult(buf); + + assertEquals(requestId, response.getRequestId()); + int actualValue = KvStateRequestSerializer.deserializeValue(response.getSerializedResult(), IntSerializer.INSTANCE); + assertEquals(expectedValue, actualValue); + } finally { + if (server != null) { + server.shutDown(); + } + + if (bootstrap != null) { + EventLoopGroup group = bootstrap.group(); + if (group != null) { + group.shutdownGracefully(); + } + } + } + } + + /** + * Creates a client bootstrap. + */ + private Bootstrap createBootstrap(final ChannelHandler... handlers) { + return new Bootstrap().group(NIO_GROUP).channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline().addLast(handlers); + } + }); + } + +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java new file mode 100644 index 0000000000000..a68c84b214304 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query.netty.message; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.query.KvStateID; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class KvStateRequestSerializerTest { + + private final ByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT; + + /** + * Tests KvState request serialization. + */ + @Test + public void testKvStateRequestSerialization() throws Exception { + long requestId = Integer.MAX_VALUE + 1337L; + KvStateID kvStateId = new KvStateID(); + byte[] serializedKeyAndNamespace = randomByteArray(1024); + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest( + alloc, + requestId, + kvStateId, + serializedKeyAndNamespace); + + int frameLength = buf.readInt(); + assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertEquals(requestId, request.getRequestId()); + assertEquals(kvStateId, request.getKvStateId()); + assertArrayEquals(serializedKeyAndNamespace, request.getSerializedKeyAndNamespace()); + } + + /** + * Tests KvState request serialization with zero-length serialized key and namespace. + */ + @Test + public void testKvStateRequestSerializationWithZeroLengthKeyAndNamespace() throws Exception { + byte[] serializedKeyAndNamespace = new byte[0]; + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest( + alloc, + 1823, + new KvStateID(), + serializedKeyAndNamespace); + + int frameLength = buf.readInt(); + assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertArrayEquals(serializedKeyAndNamespace, request.getSerializedKeyAndNamespace()); + } + + /** + * Tests that we don't try to be smart about null key and namespace. + * They should be treated explicitly. + */ + @Test(expected = NullPointerException.class) + public void testNullPointerExceptionOnNullSerializedKeyAndNamepsace() throws Exception { + new KvStateRequest(0, new KvStateID(), null); + } + + /** + * Tests KvState request result serialization. 
+ */ + @Test + public void testKvStateRequestResultSerialization() throws Exception { + long requestId = Integer.MAX_VALUE + 72727278L; + byte[] serializedResult = randomByteArray(1024); + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestResult( + alloc, + requestId, + serializedResult); + + int frameLength = buf.readInt(); + assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestResult request = KvStateRequestSerializer.deserializeKvStateRequestResult(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertEquals(requestId, request.getRequestId()); + + assertArrayEquals(serializedResult, request.getSerializedResult()); + } + + /** + * Tests KvState request result serialization with zero-length serialized result. + */ + @Test + public void testKvStateRequestResultSerializationWithZeroLengthSerializedResult() throws Exception { + byte[] serializedResult = new byte[0]; + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestResult( + alloc, + 72727278, + serializedResult); + + int frameLength = buf.readInt(); + + assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestResult request = KvStateRequestSerializer.deserializeKvStateRequestResult(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertArrayEquals(serializedResult, request.getSerializedResult()); + } + + /** + * Tests that we don't try to be smart about null results. + * They should be treated explicitly. + */ + @Test(expected = NullPointerException.class) + public void testNullPointerExceptionOnNullSerializedResult() throws Exception { + new KvStateRequestResult(0, null); + } + + /** + * Tests KvState request failure serialization. + */ + @Test + public void testKvStateRequestFailureSerialization() throws Exception { + long requestId = Integer.MAX_VALUE + 1111222L; + IllegalStateException cause = new IllegalStateException("Expected test"); + + ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestFailure( + alloc, + requestId, + cause); + + int frameLength = buf.readInt(); + assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + KvStateRequestFailure request = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertEquals(requestId, request.getRequestId()); + assertEquals(cause.getClass(), request.getCause().getClass()); + assertEquals(cause.getMessage(), request.getCause().getMessage()); + } + + /** + * Tests KvState server failure serialization. + */ + @Test + public void testServerFailureSerialization() throws Exception { + IllegalStateException cause = new IllegalStateException("Expected test"); + + ByteBuf buf = KvStateRequestSerializer.serializeServerFailure(alloc, cause); + + int frameLength = buf.readInt(); + assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf)); + Throwable request = KvStateRequestSerializer.deserializeServerFailure(buf); + assertEquals(buf.readerIndex(), frameLength + 4); + + assertEquals(cause.getClass(), request.getClass()); + assertEquals(cause.getMessage(), request.getMessage()); + } + + /** + * Tests key and namespace serialization utils. 
+ */ + @Test + public void testKeyAndNamespaceSerialization() throws Exception { + TypeSerializer keySerializer = LongSerializer.INSTANCE; + TypeSerializer namespaceSerializer = StringSerializer.INSTANCE; + + long expectedKey = Integer.MAX_VALUE + 12323L; + String expectedNamespace = "knilf"; + + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + expectedKey, keySerializer, expectedNamespace, namespaceSerializer); + + Tuple2 actual = KvStateRequestSerializer.deserializeKeyAndNamespace( + serializedKeyAndNamespace, keySerializer, namespaceSerializer); + + assertEquals(expectedKey, actual.f0.longValue()); + assertEquals(expectedNamespace, actual.f1); + } + + /** + * Tests value serialization utils. + */ + @Test + public void testValueSerialization() throws Exception { + TypeSerializer valueSerializer = LongSerializer.INSTANCE; + long expectedValue = Long.MAX_VALUE - 1292929292L; + + byte[] serializedValue = KvStateRequestSerializer.serializeValue(expectedValue, valueSerializer); + long actualValue = KvStateRequestSerializer.deserializeValue(serializedValue, valueSerializer); + + assertEquals(expectedValue, actualValue); + } + + /** + * Tests list serialization utils. + */ + @Test + public void testListSerialization() throws Exception { + TypeSerializer valueSerializer = LongSerializer.INSTANCE; + + // List + int numElements = 10; + + List expectedValues = new ArrayList<>(); + for (int i = 0; i < numElements; i++) { + expectedValues.add(ThreadLocalRandom.current().nextLong()); + } + + byte[] serializedValues = KvStateRequestSerializer.serializeList(expectedValues, valueSerializer); + List actualValues = KvStateRequestSerializer.deserializeList(serializedValues, valueSerializer); + assertEquals(expectedValues, actualValues); + + // Single value + long expectedValue = ThreadLocalRandom.current().nextLong(); + byte[] serializedValue = KvStateRequestSerializer.serializeValue(expectedValue, valueSerializer); + List actualValue = KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer); + assertEquals(1, actualValue.size()); + assertEquals(expectedValue, actualValue.get(0).longValue()); + } + + private byte[] randomByteArray(int capacity) { + byte[] bytes = new byte[capacity]; + ThreadLocalRandom.current().nextBytes(bytes); + return bytes; + } +} From 7b927e15c360607824f92d9904a0643ad03186c5 Mon Sep 17 00:00:00 2001 From: Ufuk Celebi Date: Mon, 30 May 2016 14:08:03 +0200 Subject: [PATCH 4/6] [FLINK-3779] [runtime] Add KvStateLocation lookup service - Adds an Akka-based KvStateLocation lookup service to be used by the client to look up location information. 
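For orientation, a minimal sketch (illustration only, not part of the patch) of how the
new lookup service is created and queried. It assumes code living in the same
org.apache.flink.runtime.query package, since the added classes are package-private, and
the leader retrieval service, actor system and job ID are placeholders for values the
caller already has:

    static KvStateLocation lookUpLocation(
            LeaderRetrievalService leaderRetrievalService,
            ActorSystem actorSystem,
            JobID jobId) throws Exception {

        FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS);

        // Retries are disabled here; a FixedDelayLookupRetryStrategyFactory can be
        // passed instead to keep retrying while no JobManager has been registered yet.
        KvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService(
                leaderRetrievalService,
                actorSystem,
                timeout,
                new AkkaKvStateLocationLookupService.DisabledLookupRetryStrategyFactory());

        lookupService.start();
        try {
            // Asks the current JobManager where the KvState instance registered
            // under "query-name" is located.
            Future<KvStateLocation> locationFuture =
                    lookupService.getKvStateLookupInfo(jobId, "query-name");

            return Await.result(locationFuture, timeout);
        } finally {
            lookupService.shutDown();
        }
    }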
--- .../AkkaKvStateLocationLookupService.java | 320 +++++++++++++++ .../query/KvStateLocationLookupService.java | 49 +++ .../runtime/query/UnknownJobManager.java | 33 ++ .../AkkaKvStateLocationLookupServiceTest.java | 383 ++++++++++++++++++ 4 files changed, 785 insertions(+) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupService.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationLookupService.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownJobManager.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupServiceTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupService.java new file mode 100644 index 0000000000000..ed93b2a28831a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupService.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.dispatch.Futures; +import akka.dispatch.Mapper; +import akka.dispatch.Recover; +import akka.pattern.Patterns; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.instance.ActorGateway; +import org.apache.flink.runtime.instance.AkkaActorGateway; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalListener; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; +import scala.reflect.ClassTag$; + +import java.util.UUID; +import java.util.concurrent.Callable; + +/** + * Akka-based {@link KvStateLocationLookupService} that retrieves the current + * JobManager address and uses it for lookups. + */ +class AkkaKvStateLocationLookupService implements KvStateLocationLookupService, LeaderRetrievalListener { + + private static final Logger LOG = LoggerFactory.getLogger(KvStateLocationLookupService.class); + + /** Future returned when no JobManager is available */ + private static final Future UNKNOWN_JOB_MANAGER = Futures.failed(new UnknownJobManager()); + + /** Leader retrieval service to retrieve the current job manager. */ + private final LeaderRetrievalService leaderRetrievalService; + + /** The actor system used to resolve the JobManager address. 
*/ + private final ActorSystem actorSystem; + + /** Timeout for JobManager ask-requests. */ + private final FiniteDuration askTimeout; + + /** Retry strategy factory on future failures. */ + private final LookupRetryStrategyFactory retryStrategyFactory; + + /** Current job manager future. */ + private volatile Future jobManagerFuture = UNKNOWN_JOB_MANAGER; + + /** + * Creates the Akka-based {@link KvStateLocationLookupService}. + * + * @param leaderRetrievalService Leader retrieval service to use. + * @param actorSystem Actor system to use. + * @param askTimeout Timeout for JobManager ask-requests. + * @param retryStrategyFactory Retry strategy if no JobManager available. + */ + AkkaKvStateLocationLookupService( + LeaderRetrievalService leaderRetrievalService, + ActorSystem actorSystem, + FiniteDuration askTimeout, + LookupRetryStrategyFactory retryStrategyFactory) { + + this.leaderRetrievalService = Preconditions.checkNotNull(leaderRetrievalService, "Leader retrieval service"); + this.actorSystem = Preconditions.checkNotNull(actorSystem, "Actor system"); + this.askTimeout = Preconditions.checkNotNull(askTimeout, "Ask Timeout"); + this.retryStrategyFactory = Preconditions.checkNotNull(retryStrategyFactory, "Retry strategy factory"); + } + + public void start() { + try { + leaderRetrievalService.start(this); + } catch (Exception e) { + LOG.error("Failed to start leader retrieval service", e); + throw new RuntimeException(e); + } + } + + public void shutDown() { + try { + leaderRetrievalService.stop(); + } catch (Exception e) { + LOG.error("Failed to stop leader retrieval service", e); + throw new RuntimeException(e); + } + } + + @Override + @SuppressWarnings("unchecked") + public Future getKvStateLookupInfo(final JobID jobId, final String registrationName) { + return getKvStateLookupInfo(jobId, registrationName, retryStrategyFactory.createRetryStrategy()); + } + + /** + * Returns a future holding the {@link KvStateLocation} for the given job + * and KvState registration name. + * + *

If there is currently no JobManager registered with the service, the + * request is retried. The retry behaviour is specified by the + * {@link LookupRetryStrategy} of the lookup service. + * + * @param jobId JobID the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + * @param lookupRetryStrategy Retry strategy to use for retries on UnknownJobManager failures. + * @return Future holding the {@link KvStateLocation} + */ + @SuppressWarnings("unchecked") + private Future getKvStateLookupInfo( + final JobID jobId, + final String registrationName, + final LookupRetryStrategy lookupRetryStrategy) { + + return jobManagerFuture + .flatMap(new Mapper>() { + @Override + public Future apply(ActorGateway jobManager) { + // Lookup the KvStateLocation + Object msg = new KvStateMessage.LookupKvStateLocation(jobId, registrationName); + return jobManager.ask(msg, askTimeout); + } + }, actorSystem.dispatcher()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)) + .recoverWith(new Recover>() { + @Override + public Future recover(Throwable failure) throws Throwable { + // If the Future fails with UnknownJobManager, retry + // the request. Otherwise all Futures will be failed + // during the start up phase, when the JobManager did + // not notify this service yet or leadership is lost + // intermittently. + if (failure instanceof UnknownJobManager && lookupRetryStrategy.tryRetry()) { + return Patterns.after( + lookupRetryStrategy.getRetryDelay(), + actorSystem.scheduler(), + actorSystem.dispatcher(), + new Callable>() { + @Override + public Future call() throws Exception { + return getKvStateLookupInfo( + jobId, + registrationName, + lookupRetryStrategy); + } + }); + } else { + return Futures.failed(failure); + } + } + }, actorSystem.dispatcher()); + } + + @Override + public void notifyLeaderAddress(String leaderAddress, final UUID leaderSessionID) { + if (LOG.isDebugEnabled()) { + LOG.debug("Received leader address notification {}:{}", leaderAddress, leaderSessionID); + } + + if (leaderAddress == null) { + jobManagerFuture = UNKNOWN_JOB_MANAGER; + } else { + jobManagerFuture = AkkaUtils.getActorRefFuture(leaderAddress, actorSystem, askTimeout) + .map(new Mapper() { + @Override + public ActorGateway apply(ActorRef actorRef) { + return new AkkaActorGateway(actorRef, leaderSessionID); + } + }, actorSystem.dispatcher()); + } + } + + @Override + public void handleError(Exception exception) { + jobManagerFuture = Futures.failed(exception); + } + + // ------------------------------------------------------------------------ + + /** + * Retry strategy for failed lookups. + * + *

<p>Usage:
+	 * <pre>
+	 * LookupRetryStrategy retryStrategy = LookupRetryStrategyFactory.create();
+	 *
+	 * if (retryStrategy.tryRetry()) {
+	 *     // OK to retry
+	 *     FiniteDuration retryDelay = retryStrategy.getRetryDelay();
+	 * }
+	 * </pre>
+ */ + interface LookupRetryStrategy { + + /** + * Returns the current retry. + * + * @return Current retry delay. + */ + FiniteDuration getRetryDelay(); + + /** + * Tries another retry and returns whether it is allowed or not. + * + * @return Whether it is allowed to do another restart or not. + */ + boolean tryRetry(); + + } + + /** + * Factory for retry strategies. + */ + interface LookupRetryStrategyFactory { + + /** + * Creates a new retry strategy. + * + * @return The retry strategy. + */ + LookupRetryStrategy createRetryStrategy(); + + } + + /** + * Factory for disabled retries. + */ + static class DisabledLookupRetryStrategyFactory implements LookupRetryStrategyFactory { + + private static final DisabledLookupRetryStrategy RETRY_STRATEGY = new DisabledLookupRetryStrategy(); + + @Override + public LookupRetryStrategy createRetryStrategy() { + return RETRY_STRATEGY; + } + + private static class DisabledLookupRetryStrategy implements LookupRetryStrategy { + + @Override + public FiniteDuration getRetryDelay() { + return FiniteDuration.Zero(); + } + + @Override + public boolean tryRetry() { + return false; + } + } + + } + + /** + * Factory for fixed delay retries. + */ + static class FixedDelayLookupRetryStrategyFactory implements LookupRetryStrategyFactory { + + private final int maxRetries; + private final FiniteDuration retryDelay; + + FixedDelayLookupRetryStrategyFactory(int maxRetries, FiniteDuration retryDelay) { + this.maxRetries = maxRetries; + this.retryDelay = retryDelay; + } + + @Override + public LookupRetryStrategy createRetryStrategy() { + return new FixedDelayLookupRetryStrategy(maxRetries, retryDelay); + } + + private static class FixedDelayLookupRetryStrategy implements LookupRetryStrategy { + + private final Object retryLock = new Object(); + private final int maxRetries; + private final FiniteDuration retryDelay; + private int numRetries; + + public FixedDelayLookupRetryStrategy(int maxRetries, FiniteDuration retryDelay) { + Preconditions.checkArgument(maxRetries >= 0, "Negative number maximum retries"); + this.maxRetries = maxRetries; + this.retryDelay = Preconditions.checkNotNull(retryDelay, "Retry delay"); + } + + @Override + public FiniteDuration getRetryDelay() { + synchronized (retryLock) { + return retryDelay; + } + } + + @Override + public boolean tryRetry() { + synchronized (retryLock) { + if (numRetries < maxRetries) { + numRetries++; + return true; + } else { + return false; + } + } + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationLookupService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationLookupService.java new file mode 100644 index 0000000000000..cce432eeb2a07 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateLocationLookupService.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import org.apache.flink.api.common.JobID; +import scala.concurrent.Future; + +/** + * {@link KvStateLocation} lookup service. + */ +public interface KvStateLocationLookupService { + + /** + * Starts the lookup service. + */ + void start(); + + /** + * Shuts down the lookup service. + */ + void shutDown(); + + /** + * Returns a future holding the {@link KvStateLocation} for the given job + * and KvState registration name. + * + * @param jobId JobID the KvState instance belongs to + * @param registrationName Name under which the KvState has been registered + * @return Future holding the {@link KvStateLocation} + */ + Future getKvStateLookupInfo(JobID jobId, String registrationName); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownJobManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownJobManager.java new file mode 100644 index 0000000000000..3549ed6bab7f8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownJobManager.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +/** + * Exception to fail Future with if no JobManager is currently registered at + * the {@link KvStateLocationLookupService}. + */ +class UnknownJobManager extends Exception { + + private static final long serialVersionUID = 1L; + + public UnknownJobManager() { + super("Unknown JobManager. Either the JobManager has not registered yet " + + "or has lost leadership."); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupServiceTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupServiceTest.java new file mode 100644 index 0000000000000..e9950fbd7009b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/AkkaKvStateLocationLookupServiceTest.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.Props; +import akka.actor.Status; +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.akka.FlinkUntypedActor; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService; +import org.apache.flink.runtime.query.AkkaKvStateLocationLookupService.LookupRetryStrategy; +import org.apache.flink.runtime.query.AkkaKvStateLocationLookupService.LookupRetryStrategyFactory; +import org.apache.flink.runtime.query.KvStateMessage.LookupKvStateLocation; +import org.apache.flink.util.Preconditions; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class AkkaKvStateLocationLookupServiceTest { + + /** The default timeout. */ + private static final FiniteDuration TIMEOUT = new FiniteDuration(10, TimeUnit.SECONDS); + + /** Test actor system shared between the tests. */ + private static ActorSystem testActorSystem; + + @BeforeClass + public static void setUp() throws Exception { + testActorSystem = AkkaUtils.createLocalActorSystem(new Configuration()); + } + + @AfterClass + public static void tearDown() throws Exception { + if (testActorSystem != null) { + testActorSystem.shutdown(); + } + } + + /** + * Tests responses if no leader notification has been reported or leadership + * has been lost (leaderAddress = null). 
+ */ + @Test + public void testNoJobManagerRegistered() throws Exception { + TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService(); + Queue received = new LinkedBlockingQueue<>(); + + AkkaKvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService( + leaderRetrievalService, + testActorSystem, + TIMEOUT, + new AkkaKvStateLocationLookupService.DisabledLookupRetryStrategyFactory()); + + lookupService.start(); + + // + // No leader registered initially => fail with UnknownJobManager + // + try { + JobID jobId = new JobID(); + String name = "coffee"; + + Future locationFuture = lookupService.getKvStateLookupInfo(jobId, name); + + Await.result(locationFuture, TIMEOUT); + fail("Did not throw expected Exception"); + } catch (UnknownJobManager ignored) { + // Expected + } + + assertEquals("Received unexpected lookup", 0, received.size()); + + // + // Leader registration => communicate with new leader + // + UUID leaderSessionId = null; + KvStateLocation expected = new KvStateLocation(new JobID(), new JobVertexID(), 8282, "tea"); + + ActorRef testActor = LookupResponseActor.create(received, leaderSessionId, expected); + + String testActorAddress = AkkaUtils.getAkkaURL(testActorSystem, testActor); + + // Notify the service about a leader + leaderRetrievalService.notifyListener(testActorAddress, leaderSessionId); + + JobID jobId = new JobID(); + String name = "tea"; + + // Verify that the leader response is handled + KvStateLocation location = Await.result(lookupService.getKvStateLookupInfo(jobId, name), TIMEOUT); + assertEquals(expected, location); + + // Verify that the correct message was sent to the leader + assertEquals(1, received.size()); + + verifyLookupMsg(received.poll(), jobId, name); + + // + // Leader loss => fail with UnknownJobManager + // + leaderRetrievalService.notifyListener(null, null); + + try { + Future locationFuture = lookupService + .getKvStateLookupInfo(new JobID(), "coffee"); + + Await.result(locationFuture, TIMEOUT); + fail("Did not throw expected Exception"); + } catch (UnknownJobManager ignored) { + // Expected + } + + // No new messages received + assertEquals(0, received.size()); + } + + /** + * Tests that messages are properly decorated with the leader session ID. 
+ */ + @Test + public void testLeaderSessionIdChange() throws Exception { + TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService(); + Queue received = new LinkedBlockingQueue<>(); + + AkkaKvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService( + leaderRetrievalService, + testActorSystem, + TIMEOUT, + new AkkaKvStateLocationLookupService.DisabledLookupRetryStrategyFactory()); + + lookupService.start(); + + // Create test actors with random leader session IDs + KvStateLocation expected1 = new KvStateLocation(new JobID(), new JobVertexID(), 8282, "salt"); + UUID leaderSessionId1 = UUID.randomUUID(); + ActorRef testActor1 = LookupResponseActor.create(received, leaderSessionId1, expected1); + String testActorAddress1 = AkkaUtils.getAkkaURL(testActorSystem, testActor1); + + KvStateLocation expected2 = new KvStateLocation(new JobID(), new JobVertexID(), 22321, "pepper"); + UUID leaderSessionId2 = UUID.randomUUID(); + ActorRef testActor2 = LookupResponseActor.create(received, leaderSessionId1, expected2); + String testActorAddress2 = AkkaUtils.getAkkaURL(testActorSystem, testActor2); + + JobID jobId = new JobID(); + + // + // Notify about first leader + // + leaderRetrievalService.notifyListener(testActorAddress1, leaderSessionId1); + + KvStateLocation location = Await.result(lookupService.getKvStateLookupInfo(jobId, "rock"), TIMEOUT); + assertEquals(expected1, location); + + assertEquals(1, received.size()); + verifyLookupMsg(received.poll(), jobId, "rock"); + + // + // Notify about second leader + // + leaderRetrievalService.notifyListener(testActorAddress2, leaderSessionId2); + + location = Await.result(lookupService.getKvStateLookupInfo(jobId, "roll"), TIMEOUT); + assertEquals(expected2, location); + + assertEquals(1, received.size()); + verifyLookupMsg(received.poll(), jobId, "roll"); + } + + /** + * Tests that lookups are retried when no leader notification is available. 
+ */ + @Test + public void testRetryOnUnknownJobManager() throws Exception { + final Queue retryStrategies = new LinkedBlockingQueue<>(); + + LookupRetryStrategyFactory retryStrategy = + new LookupRetryStrategyFactory() { + @Override + public LookupRetryStrategy createRetryStrategy() { + return retryStrategies.poll(); + } + }; + + final TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService(); + + AkkaKvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService( + leaderRetrievalService, + testActorSystem, + TIMEOUT, + retryStrategy); + + lookupService.start(); + + // + // Test call to retry + // + final AtomicBoolean hasRetried = new AtomicBoolean(); + retryStrategies.add( + new LookupRetryStrategy() { + @Override + public FiniteDuration getRetryDelay() { + return FiniteDuration.Zero(); + } + + @Override + public boolean tryRetry() { + if (hasRetried.compareAndSet(false, true)) { + return true; + } + return false; + } + }); + + Future locationFuture = lookupService.getKvStateLookupInfo(new JobID(), "yessir"); + + Await.ready(locationFuture, TIMEOUT); + assertTrue("Did not retry ", hasRetried.get()); + + // + // Test leader notification after retry + // + Queue received = new LinkedBlockingQueue<>(); + + KvStateLocation expected = new KvStateLocation(new JobID(), new JobVertexID(), 12122, "garlic"); + ActorRef testActor = LookupResponseActor.create(received, null, expected); + final String testActorAddress = AkkaUtils.getAkkaURL(testActorSystem, testActor); + + retryStrategies.add(new LookupRetryStrategy() { + @Override + public FiniteDuration getRetryDelay() { + return FiniteDuration.apply(100, TimeUnit.MILLISECONDS); + } + + @Override + public boolean tryRetry() { + leaderRetrievalService.notifyListener(testActorAddress, null); + return true; + } + }); + + KvStateLocation location = Await.result(lookupService.getKvStateLookupInfo(new JobID(), "yessir"), TIMEOUT); + assertEquals(expected, location); + } + + @Test + public void testUnexpectedResponseType() throws Exception { + TestingLeaderRetrievalService leaderRetrievalService = new TestingLeaderRetrievalService(); + Queue received = new LinkedBlockingQueue<>(); + + AkkaKvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService( + leaderRetrievalService, + testActorSystem, + TIMEOUT, + new AkkaKvStateLocationLookupService.DisabledLookupRetryStrategyFactory()); + + lookupService.start(); + + // Create test actors with random leader session IDs + String expected = "unexpected-response-type"; + ActorRef testActor = LookupResponseActor.create(received, null, expected); + String testActorAddress = AkkaUtils.getAkkaURL(testActorSystem, testActor); + + leaderRetrievalService.notifyListener(testActorAddress, null); + + try { + Await.result(lookupService.getKvStateLookupInfo(new JobID(), "spicy"), TIMEOUT); + fail("Did not throw expected Exception"); + } catch (Throwable ignored) { + // Expected + } + } + + private final static class LookupResponseActor extends FlinkUntypedActor { + + /** Received lookup messages */ + private final Queue receivedLookups; + + /** Responses on KvStateMessage.LookupKvStateLocation messages */ + private final Queue lookupResponses; + + /** The leader session ID */ + private UUID leaderSessionId; + + public LookupResponseActor( + Queue receivedLookups, + UUID leaderSessionId, Object... 
lookupResponses) { + + this.receivedLookups = Preconditions.checkNotNull(receivedLookups, "Received lookups"); + this.leaderSessionId = leaderSessionId; + this.lookupResponses = new ArrayDeque<>(); + + if (lookupResponses != null) { + for (Object resp : lookupResponses) { + this.lookupResponses.add(resp); + } + } + } + + @Override + public void handleMessage(Object message) throws Exception { + if (message instanceof LookupKvStateLocation) { + // Add to received lookups queue + receivedLookups.add((LookupKvStateLocation) message); + + Object msg = lookupResponses.poll(); + if (msg != null) { + if (msg instanceof Throwable) { + sender().tell(new Status.Failure((Throwable) msg), self()); + } else { + sender().tell(new Status.Success(msg), self()); + } + } + } else if (message instanceof UUID) { + this.leaderSessionId = (UUID) message; + } else { + LOG.debug("Received unhandled message: {}", message); + } + } + + @Override + protected UUID getLeaderSessionID() { + return leaderSessionId; + } + + private static ActorRef create( + Queue receivedLookups, + UUID leaderSessionId, + Object... lookupResponses) { + + return testActorSystem.actorOf(Props.create( + LookupResponseActor.class, + receivedLookups, + leaderSessionId, + lookupResponses)); + } + } + + private static void verifyLookupMsg( + LookupKvStateLocation lookUpMsg, + JobID expectedJobId, + String expectedName) { + + assertNotNull(lookUpMsg); + assertEquals(expectedJobId, lookUpMsg.getJobId()); + assertEquals(expectedName, lookUpMsg.getRegistrationName()); + } + +} From 21f7be0ece90b1afaa5efd1bd24fef3128bb4d30 Mon Sep 17 00:00:00 2001 From: Ufuk Celebi Date: Mon, 30 May 2016 14:08:24 +0200 Subject: [PATCH 5/6] [FLINK-3779] [runtime] Add QueryableStateClient - Adds a client, which works with the network client and location lookup service to query KvState instances. - Furthermore, location information is cached. --- docs/setup/config.md | 14 + .../flink/configuration/ConfigConstants.java | 38 ++ .../io/network/NetworkEnvironment.java | 40 +- .../runtime/query/QueryableStateClient.java | 355 ++++++++++++++++ .../query/UnknownKvStateKeyGroupLocation.java | 29 ++ .../runtime/query/UnknownKvStateLocation.java | 35 ++ .../flink/runtime/query/package-info.java | 60 +++ .../NetworkEnvironmentConfiguration.scala | 3 + .../runtime/taskmanager/TaskManager.scala | 15 + .../io/network/NetworkEnvironmentTest.java | 5 +- .../query/QueryableStateClientTest.java | 394 ++++++++++++++++++ ...kManagerComponentsStartupShutdownTest.java | 4 +- 12 files changed, 975 insertions(+), 17 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/QueryableStateClient.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateKeyGroupLocation.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateLocation.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/query/package-info.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java diff --git a/docs/setup/config.md b/docs/setup/config.md index 5baccfabb0807..41b7ba7de527c 100644 --- a/docs/setup/config.md +++ b/docs/setup/config.md @@ -313,6 +313,20 @@ of the JobManager, because the same ActorSystem is used. Its not possible to use - `env.log.dir`: (Defaults to the `log` directory under Flink's home) Defines the directory where the Flink logs are saved. It has to be an absolute path. 
+## Queryable State + +### Server + +- `query.server.port`: Port to bind queryable state server to (Default: `0`, binds to random port). +- `query.server.network-threads`: Number of network (Netty's event loop) Threads for queryable state server (Default: `0`, picks number of slots). +- `query.server.query-threads`: Number of query Threads for queryable state server (Default: `0`, picks number of slots). + +### Client + +- `query.client.network-threads`: Number of network (Netty's event loop) Threads for queryable state client (Default: `0`, picks number of available cores as returned by `Runtime.getRuntime().availableProcessors()`). +- `query.client.lookup.num-retries`: Number of retries on KvState lookup failure due to unavailable JobManager (Default: `3`). +- `query.client.lookup.retry-delay`: Retry delay in milliseconds on KvState lookup failure due to unavailable JobManager (Default: `1000`). + ## Metrics - `metrics.reporters`: The list of named reporters, i.e. "foo,bar". diff --git a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java index 928497fd3f306..98a843dd33e1c 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java @@ -1066,6 +1066,44 @@ public final class ConfigConstants { /** ZooKeeper default leader port. */ public static final int DEFAULT_ZOOKEEPER_LEADER_PORT = 3888; + // ------------------------- Queryable state ------------------------------ + + /** Port to bind KvState server to. */ + public static final String QUERYABLE_STATE_SERVER_PORT = "query.server.port"; + + /** Number of network (event loop) threads for the KvState server. */ + public static final String QUERYABLE_STATE_SERVER_NETWORK_THREADS = "query.server.network-threads"; + + /** Number of query threads for the KvState server. */ + public static final String QUERYABLE_STATE_SERVER_QUERY_THREADS = "query.server.query-threads"; + + /** Default port to bind KvState server to (0 => pick random free port). */ + public static final int DEFAULT_QUERYABLE_STATE_SERVER_PORT = 0; + + /** Default Number of network (event loop) threads for the KvState server (0 => #slots). */ + public static final int DEFAULT_QUERYABLE_STATE_SERVER_NETWORK_THREADS = 0; + + /** Default number of query threads for the KvState server (0 => #slots). */ + public static final int DEFAULT_QUERYABLE_STATE_SERVER_QUERY_THREADS = 0; + + /** Number of network (event loop) threads for the KvState client. */ + public static final String QUERYABLE_STATE_CLIENT_NETWORK_THREADS = "query.client.network-threads"; + + /** Number of retries on location lookup failures. */ + public static final String QUERYABLE_STATE_CLIENT_LOOKUP_RETRIES = "query.client.lookup.num-retries"; + + /** Retry delay on location lookup failures (millis). */ + public static final String QUERYABLE_STATE_CLIENT_LOOKUP_RETRY_DELAY = "query.client.lookup.retry-delay"; + + /** Default number of query threads for the KvState client (0 => #cores) */ + public static final int DEFAULT_QUERYABLE_STATE_CLIENT_NETWORK_THREADS = 0; + + /** Default number of retries on location lookup failures. */ + public static final int DEFAULT_QUERYABLE_STATE_CLIENT_LOOKUP_RETRIES = 3; + + /** Default retry delay on location lookup failures. 
*/ + public static final int DEFAULT_QUERYABLE_STATE_CLIENT_LOOKUP_RETRY_DELAY = 1000; + // ----------------------------- Environment Variables ---------------------------- /** The environment variable name which contains the location of the configuration directory */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 283d804509893..844bc2dadea2c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -45,11 +45,11 @@ import org.apache.flink.runtime.query.KvStateRegistryListener; import org.apache.flink.runtime.query.KvStateServerAddress; import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.query.netty.DisabledKvStateRequestStats; +import org.apache.flink.runtime.query.netty.KvStateServer; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManager; -import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats; -import org.apache.flink.runtime.query.netty.KvStateServer; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -242,21 +242,33 @@ public void associateWithTaskManagerAndJobManager( try { kvStateRegistry = new KvStateRegistry(); - kvStateServer = new KvStateServer( - connectionInfo.address(), - 0, - 1, - 10, - kvStateRegistry, - new AtomicKvStateRequestStats()); + if (nettyConfig.isDefined()) { + int numNetworkThreads = configuration.queryServerNetworkThreads(); + if (numNetworkThreads == 0) { + numNetworkThreads = nettyConfig.get().getNumberOfSlots(); + } - kvStateServer.start(); + int numQueryThreads = configuration.queryServerNetworkThreads(); + if (numQueryThreads == 0) { + numQueryThreads = nettyConfig.get().getNumberOfSlots(); + } - KvStateRegistryListener listener = new JobManagerKvStateRegistryListener( - jobManagerGateway, - kvStateServer.getAddress()); + kvStateServer = new KvStateServer( + connectionInfo.address(), + configuration.queryServerPort(), + numNetworkThreads, + numQueryThreads, + kvStateRegistry, + new DisabledKvStateRequestStats()); - kvStateRegistry.registerListener(listener); + kvStateServer.start(); + + KvStateRegistryListener listener = new JobManagerKvStateRegistryListener( + jobManagerGateway, + kvStateServer.getAddress()); + + kvStateRegistry.registerListener(listener); + } } catch (Throwable t) { throw new IOException("Failed to instantiate KvState management components: " + t.getMessage(), t); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/QueryableStateClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/QueryableStateClient.java new file mode 100644 index 0000000000000..0e1ea57f339d5 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/QueryableStateClient.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import akka.actor.ActorSystem; +import akka.dispatch.Futures; +import akka.dispatch.Mapper; +import akka.dispatch.Recover; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.IllegalConfigurationException; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; +import org.apache.flink.runtime.query.AkkaKvStateLocationLookupService.FixedDelayLookupRetryStrategyFactory; +import org.apache.flink.runtime.query.AkkaKvStateLocationLookupService.LookupRetryStrategyFactory; +import org.apache.flink.runtime.query.netty.DisabledKvStateRequestStats; +import org.apache.flink.runtime.query.netty.KvStateClient; +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.runtime.query.netty.UnknownKeyOrNamespace; +import org.apache.flink.runtime.query.netty.UnknownKvStateID; +import org.apache.flink.runtime.util.LeaderRetrievalUtils; +import org.apache.flink.util.MathUtils; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.concurrent.ExecutionContext; +import scala.concurrent.Future; +import scala.concurrent.duration.Duration; +import scala.concurrent.duration.FiniteDuration; + +import java.net.ConnectException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * Client for queryable state. + * + *

You can mark state as queryable via {@link StateDescriptor#setQueryable(String)}. + * The state instance created from this descriptor will be published for queries + * when it's created on the TaskManagers and the location will be reported to + * the JobManager. + * + *
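+ * <p>A minimal sketch of publishing state this way (the state name, default
+ * value, and queryable name below are only illustrative):
+ *
+ * <pre>{@code
+ * ValueStateDescriptor<Long> descriptor = new ValueStateDescriptor<>(
+ *     "count", LongSerializer.INSTANCE, 0L);
+ *
+ * // Publish the state for queries under the name "count-query"
+ * descriptor.setQueryable("count-query");
+ * }</pre>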

The client resolves the location of the requested KvState via the + * JobManager. Resolved locations are cached. When the server address of the + * requested KvState instance is determined, the client sends out a request to + * the server. + */ +public class QueryableStateClient { + + private static final Logger LOG = LoggerFactory.getLogger(QueryableStateClient.class); + + /** + * {@link KvStateLocation} lookup to resolve the address of KvState instances. + */ + private final KvStateLocationLookupService lookupService; + + /** + * Network client for queries against {@link KvStateServer} instances. + */ + private final KvStateClient kvStateClient; + + /** + * Execution context. + */ + private final ExecutionContext executionContext; + + /** + * Cache for {@link KvStateLocation} instances keyed by job and name. + */ + private final ConcurrentMap, Future> lookupCache = + new ConcurrentHashMap<>(); + + /** This is != null, iff we started the actor system. */ + private final ActorSystem actorSystem; + + /** + * Creates a client from the given configuration. + * + *
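+ * <p>For illustration, a client could be set up roughly as follows (a sketch;
+ * the configuration must additionally point to the JobManager used for the
+ * location lookup, which is omitted here, and the values shown are simply the
+ * documented defaults):
+ *
+ * <pre>{@code
+ * Configuration config = new Configuration();
+ * config.setInteger(ConfigConstants.QUERYABLE_STATE_CLIENT_NETWORK_THREADS, 0); // 0 => #cores
+ * config.setInteger(ConfigConstants.QUERYABLE_STATE_CLIENT_LOOKUP_RETRIES, 3);
+ * config.setInteger(ConfigConstants.QUERYABLE_STATE_CLIENT_LOOKUP_RETRY_DELAY, 1000); // millis
+ *
+ * QueryableStateClient client = new QueryableStateClient(config);
+ * }</pre>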

This will create multiple Thread pools: one for the started actor + * system and another for the network client. + * + * @param config Configuration to use. + * @throws Exception Failures are forwarded + */ + public QueryableStateClient(Configuration config) throws Exception { + Preconditions.checkNotNull(config, "Configuration"); + + // Create a leader retrieval service + LeaderRetrievalService leaderRetrievalService = LeaderRetrievalUtils + .createLeaderRetrievalService(config); + + // Get the ask timeout + String askTimeoutString = config.getString( + ConfigConstants.AKKA_ASK_TIMEOUT, + ConfigConstants.DEFAULT_AKKA_ASK_TIMEOUT); + + Duration timeout = FiniteDuration.apply(askTimeoutString); + if (!timeout.isFinite()) { + throw new IllegalConfigurationException(ConfigConstants.AKKA_ASK_TIMEOUT + + " is not a finite timeout ('" + askTimeoutString + "')"); + } + + FiniteDuration askTimeout = (FiniteDuration) timeout; + + int lookupRetries = config.getInteger( + ConfigConstants.QUERYABLE_STATE_CLIENT_LOOKUP_RETRIES, + ConfigConstants.DEFAULT_QUERYABLE_STATE_CLIENT_LOOKUP_RETRIES); + + int lookupRetryDelayMillis = config.getInteger( + ConfigConstants.QUERYABLE_STATE_CLIENT_LOOKUP_RETRY_DELAY, + ConfigConstants.DEFAULT_QUERYABLE_STATE_CLIENT_LOOKUP_RETRY_DELAY); + + // Retries if no JobManager is around + LookupRetryStrategyFactory retryStrategy = new FixedDelayLookupRetryStrategyFactory( + lookupRetries, + FiniteDuration.apply(lookupRetryDelayMillis, "ms")); + + // Create the actor system + @SuppressWarnings("unchecked") + Option> remoting = new Some(new Tuple2<>("", 0)); + this.actorSystem = AkkaUtils.createActorSystem(config, remoting); + + AkkaKvStateLocationLookupService lookupService = new AkkaKvStateLocationLookupService( + leaderRetrievalService, + actorSystem, + askTimeout, + retryStrategy); + + int numEventLoopThreads = config.getInteger( + ConfigConstants.QUERYABLE_STATE_CLIENT_NETWORK_THREADS, + ConfigConstants.DEFAULT_QUERYABLE_STATE_CLIENT_NETWORK_THREADS); + + if (numEventLoopThreads == 0) { + numEventLoopThreads = Runtime.getRuntime().availableProcessors(); + } + + // Create the network client + KvStateClient networkClient = new KvStateClient( + numEventLoopThreads, + new DisabledKvStateRequestStats()); + + this.lookupService = lookupService; + this.kvStateClient = networkClient; + this.executionContext = actorSystem.dispatcher(); + + this.lookupService.start(); + } + + /** + * Creates a client. + * + * @param lookupService Location lookup service + * @param kvStateClient Network client for queries + * @param executionContext Execution context for futures + */ + public QueryableStateClient( + KvStateLocationLookupService lookupService, + KvStateClient kvStateClient, + ExecutionContext executionContext) { + + this.lookupService = Preconditions.checkNotNull(lookupService, "KvStateLocationLookupService"); + this.kvStateClient = Preconditions.checkNotNull(kvStateClient, "KvStateClient"); + this.executionContext = Preconditions.checkNotNull(executionContext, "ExecutionContext"); + this.actorSystem = null; + + this.lookupService.start(); + } + + /** + * Returns the execution context of this client. + * + * @return The execution context used by the client. + */ + public ExecutionContext getExecutionContext() { + return executionContext; + } + + /** + * Shuts down the client and all components. 
+ */ + public void shutDown() { + try { + lookupService.shutDown(); + } catch (Throwable t) { + LOG.error("Failed to shut down KvStateLookupService", t); + } + + try { + kvStateClient.shutDown(); + } catch (Throwable t) { + LOG.error("Failed to shut down KvStateClient", t); + } + + if (actorSystem != null) { + try { + actorSystem.shutdown(); + } catch (Throwable t) { + LOG.error("Failed to shut down ActorSystem"); + } + } + } + + /** + * Returns a future holding the serialized request result. + * + *
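+ * <p>A sketch of a full round trip, using the serializer helpers that the
+ * tests of this change use (the job ID, state name, Integer key, and timeout
+ * are illustrative):
+ *
+ * <pre>{@code
+ * byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ *     key, IntSerializer.INSTANCE, VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);
+ *
+ * Future<byte[]> future = client.getKvState(
+ *     jobId, "count-query", key.hashCode(), serializedKeyAndNamespace);
+ *
+ * byte[] serializedValue = Await.result(future, timeout);
+ * int value = KvStateRequestSerializer.deserializeValue(serializedValue, IntSerializer.INSTANCE);
+ * }</pre>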

If the server does not serve a KvState instance with the given ID, + * the Future will be failed with a {@link UnknownKvStateID}. + * + *

If the KvState instance does not hold any data for the given key + * and namespace, the Future will be failed with a {@link UnknownKeyOrNamespace}. + * + *

All other failures are forwarded to the Future. + * + * @param jobId JobID of the job the queryable state + * belongs to + * @param queryableStateName Name under which the state is queryable + * @param keyHashCode Integer hash code of the key (result of + * a call to {@link Object#hashCode()} + * @param serializedKeyAndNamespace Serialized key and namespace to query + * KvState instance with + * @return Future holding the serialized result + */ + @SuppressWarnings("unchecked") + public Future getKvState( + final JobID jobId, + final String queryableStateName, + final int keyHashCode, + final byte[] serializedKeyAndNamespace) { + + return getKvState(jobId, queryableStateName, keyHashCode, serializedKeyAndNamespace, false) + .recoverWith(new Recover>() { + @Override + public Future recover(Throwable failure) throws Throwable { + if (failure instanceof UnknownKvStateID || + failure instanceof UnknownKvStateKeyGroupLocation || + failure instanceof UnknownKvStateLocation || + failure instanceof ConnectException) { + // These failures are likely to be caused by out-of-sync + // KvStateLocation. Therefore we retry this query and + // force look up the location. + return getKvState( + jobId, + queryableStateName, + keyHashCode, + serializedKeyAndNamespace, + true); + } else { + return Futures.failed(failure); + } + } + }, executionContext); + } + + /** + * Returns a future holding the serialized request result. + * + * @param jobId JobID of the job the queryable state + * belongs to + * @param queryableStateName Name under which the state is queryable + * @param keyHashCode Integer hash code of the key (result of + * a call to {@link Object#hashCode()} + * @param serializedKeyAndNamespace Serialized key and namespace to query + * KvState instance with + * @param forceLookup Flag to force lookup of the {@link KvStateLocation} + * @return Future holding the serialized result + */ + private Future getKvState( + final JobID jobId, + final String queryableStateName, + final int keyHashCode, + final byte[] serializedKeyAndNamespace, + boolean forceLookup) { + + return getKvStateLookupInfo(jobId, queryableStateName, forceLookup) + .flatMap(new Mapper>() { + @Override + public Future apply(KvStateLocation lookup) { + int keyGroupIndex = MathUtils.murmurHash(keyHashCode) % lookup.getNumKeyGroups(); + + KvStateServerAddress serverAddress = lookup.getKvStateServerAddress(keyGroupIndex); + if (serverAddress == null) { + return Futures.failed(new UnknownKvStateKeyGroupLocation()); + } else { + // Query server + KvStateID kvStateId = lookup.getKvStateID(keyGroupIndex); + return kvStateClient.getKvState(serverAddress, kvStateId, serializedKeyAndNamespace); + } + } + }, executionContext); + } + + /** + * Lookup the {@link KvStateLocation} for the given job and queryable state + * name. + * + *

The job manager will be queried for the location only if forced or no + * cached location can be found. There are no guarantees about + * + * @param jobId JobID the state instance belongs to. + * @param queryableStateName Name under which the state instance has been published. + * @param forceUpdate Flag to indicate whether to force a update via the lookup service. + * @return Future holding the KvStateLocation + */ + private Future getKvStateLookupInfo( + JobID jobId, + final String queryableStateName, + boolean forceUpdate) { + + if (forceUpdate) { + Future lookupFuture = lookupService + .getKvStateLookupInfo(jobId, queryableStateName); + lookupCache.put(new Tuple2<>(jobId, queryableStateName), lookupFuture); + return lookupFuture; + } else { + Tuple2 cacheKey = new Tuple2<>(jobId, queryableStateName); + final Future cachedFuture = lookupCache.get(cacheKey); + + if (cachedFuture == null) { + Future lookupFuture = lookupService + .getKvStateLookupInfo(jobId, queryableStateName); + + Future previous = lookupCache.putIfAbsent(cacheKey, lookupFuture); + if (previous == null) { + return lookupFuture; + } else { + return previous; + } + } else { + return cachedFuture; + } + } + } + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateKeyGroupLocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateKeyGroupLocation.java new file mode 100644 index 0000000000000..8f62be5be8609 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateKeyGroupLocation.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +/** + * Exception thrown if there is no location information available for the given + * key group in a {@link KvStateLocation} instance. + */ +class UnknownKvStateKeyGroupLocation extends Exception { + + private static final long serialVersionUID = 1L; + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateLocation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateLocation.java new file mode 100644 index 0000000000000..38cc1ccbdfcc7 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/UnknownKvStateLocation.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +/** + * Thrown if there is no {@link KvStateLocation} found for the requested + * registration name. + * + *

This indicates that the requested KvState instance is not registered + * under this name (yet). + */ +public class UnknownKvStateLocation extends Exception { + + private static final long serialVersionUID = 1L; + + public UnknownKvStateLocation(String registrationName) { + super("No KvStateLocation found for KvState instance with name '" + registrationName + "'."); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/package-info.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/package-info.java new file mode 100644 index 0000000000000..07a4396fecd69 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/package-info.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * This package contains all KvState query related classes. + * + *

<b>TaskManager and JobManager</b>

+ * + *

State backends register queryable state instances at the {@link + * org.apache.flink.runtime.query.KvStateRegistry}. + * There is one registry per TaskManager. Registered KvState instances are + * reported to the JobManager, where they are aggregated at the {@link + * org.apache.flink.runtime.query.KvStateLocationRegistry}. + * + *

Instances of {@link org.apache.flink.runtime.query.KvStateLocation} contain + * all information needed for a client to query a KvState instance. + * + *

See also:
+ * <ul>
+ *   <li>{@link org.apache.flink.runtime.query.KvStateRegistry}</li>
+ *   <li>{@link org.apache.flink.runtime.query.TaskKvStateRegistry}</li>
+ *   <li>{@link org.apache.flink.runtime.query.KvStateLocation}</li>
+ *   <li>{@link org.apache.flink.runtime.query.KvStateLocationRegistry}</li>
+ * </ul>
+ *

<b>Client</b>

+ * + * The {@link org.apache.flink.runtime.query.QueryableStateClient} is used + * to query KvState instances. The client takes care of {@link + * org.apache.flink.runtime.query.KvStateLocation} lookup and caching. Queries + * are then dispatched via the network client. + * + *
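+ * <p>Queries are routed by the key's hash code: the client derives the key
+ * group index, and with it the target TaskManager, from the resolved
+ * {@link org.apache.flink.runtime.query.KvStateLocation} roughly as sketched
+ * below (where {@code keyHashCode} is the result of {@link Object#hashCode()}
+ * on the queried key):
+ *
+ * <pre>{@code
+ * int keyGroupIndex = MathUtils.murmurHash(keyHashCode) % location.getNumKeyGroups();
+ * KvStateServerAddress address = location.getKvStateServerAddress(keyGroupIndex);
+ * }</pre>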

<b>JobManager Communication</b>

+ * + *

The JobManager is queried for {@link org.apache.flink.runtime.query.KvStateLocation} + * instances via the {@link org.apache.flink.runtime.query.KvStateLocationLookupService}. + * The client caches resolved locations and dispatches queries directly to the + * TaskManager. + * + *

<b>TaskManager Communication</b>

+ * + *

After the location has been resolved, the TaskManager is queried via the + * {@link org.apache.flink.runtime.query.netty.KvStateClient}. + */ +package org.apache.flink.runtime.query; diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/NetworkEnvironmentConfiguration.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/NetworkEnvironmentConfiguration.scala index 065211c934e2b..0788d7c28d3a4 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/NetworkEnvironmentConfiguration.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/NetworkEnvironmentConfiguration.scala @@ -27,5 +27,8 @@ case class NetworkEnvironmentConfiguration( networkBufferSize: Int, memoryType: MemoryType, ioMode: IOMode, + queryServerPort: Int, + queryServerNetworkThreads: Int, + queryServerQueryThreads: Int, nettyConfig: Option[NettyConfig] = None, partitionRequestInitialAndMaxBackoff: (Integer, Integer) = (500, 3000)) diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index 7c4b867a114ec..e732214ceec98 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -2097,11 +2097,26 @@ object TaskManager { val ioMode : IOMode = if (syncOrAsync == "async") IOMode.ASYNC else IOMode.SYNC + val queryServerPort = configuration.getInteger( + ConfigConstants.QUERYABLE_STATE_SERVER_PORT, + ConfigConstants.DEFAULT_QUERYABLE_STATE_SERVER_PORT) + + val queryServerNetworkThreads = configuration.getInteger( + ConfigConstants.QUERYABLE_STATE_SERVER_NETWORK_THREADS, + ConfigConstants.DEFAULT_QUERYABLE_STATE_SERVER_NETWORK_THREADS) + + val queryServerQueryThreads = configuration.getInteger( + ConfigConstants.QUERYABLE_STATE_SERVER_QUERY_THREADS, + ConfigConstants.DEFAULT_QUERYABLE_STATE_SERVER_QUERY_THREADS) + val networkConfig = NetworkEnvironmentConfiguration( numNetworkBuffers, pageSize, memType, ioMode, + queryServerPort, + queryServerNetworkThreads, + queryServerQueryThreads, nettyConfig) // ----> timeouts, library caching, profiling diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index 938e66190e80e..4597e3bf2e897 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -80,7 +80,7 @@ public void testAssociateDisassociate() { NettyConfig nettyConf = new NettyConfig(InetAddress.getLocalHost(), port, BUFFER_SIZE, 1, new Configuration()); NetworkEnvironmentConfiguration config = new NetworkEnvironmentConfiguration( NUM_BUFFERS, BUFFER_SIZE, MemoryType.HEAP, - IOManager.IOMode.SYNC, new Some<>(nettyConf), + IOManager.IOMode.SYNC, 0, 0, 0, new Some<>(nettyConf), new Tuple2<>(0, 0)); NetworkEnvironment env = new NetworkEnvironment( @@ -174,6 +174,9 @@ public void testEagerlyDeployConsumers() throws Exception { 1024, MemoryType.HEAP, IOManager.IOMode.SYNC, + 0, + 0, + 0, Some.empty(), new Tuple2<>(0, 0)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java 
new file mode 100644 index 0000000000000..36f2b45135159 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.query; + +import akka.actor.ActorSystem; +import akka.dispatch.Futures; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.query.netty.AtomicKvStateRequestStats; +import org.apache.flink.runtime.query.netty.KvStateClient; +import org.apache.flink.runtime.query.netty.KvStateServer; +import org.apache.flink.runtime.query.netty.UnknownKvStateID; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.state.memory.MemValueState; +import org.apache.flink.util.MathUtils; +import org.junit.AfterClass; +import org.junit.Test; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; + +import java.net.ConnectException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class QueryableStateClientTest { + + private static final ActorSystem testActorSystem = AkkaUtils.createLocalActorSystem(new Configuration()); + + private static final FiniteDuration timeout = new FiniteDuration(100, TimeUnit.SECONDS); + + @AfterClass + public static void tearDown() throws Exception { + if (testActorSystem != null) { + testActorSystem.shutdown(); + } + } + + /** + * All failures should lead to a retry with a forced location lookup. + * + * UnknownKvStateID, UnknownKvStateKeyGroupLocation, UnknownKvStateLocation, + * ConnectException are checked explicitly as these indicate out-of-sync + * KvStateLocation. 
+ */ + @Test + public void testForceLookupOnOutdatedLocation() throws Exception { + KvStateLocationLookupService lookupService = mock(KvStateLocationLookupService.class); + KvStateClient networkClient = mock(KvStateClient.class); + + QueryableStateClient client = new QueryableStateClient( + lookupService, + networkClient, + testActorSystem.dispatcher()); + + try { + JobID jobId = new JobID(); + int numKeyGroups = 4; + + // + // UnknownKvStateLocation + // + String query1 = "lucky"; + + Future unknownKvStateLocation = Futures.failed( + new UnknownKvStateLocation(query1)); + + when(lookupService.getKvStateLookupInfo(eq(jobId), eq(query1))) + .thenReturn(unknownKvStateLocation); + + Future result = client.getKvState( + jobId, + query1, + 0, + new byte[0]); + + try { + Await.result(result, timeout); + fail("Did not throw expected UnknownKvStateLocation exception"); + } catch (UnknownKvStateLocation ignored) { + // Expected + } + + verify(lookupService, times(2)).getKvStateLookupInfo(eq(jobId), eq(query1)); + + // + // UnknownKvStateKeyGroupLocation + // + String query2 = "unlucky"; + + Future unknownKeyGroupLocation = Futures.successful( + new KvStateLocation(jobId, new JobVertexID(), numKeyGroups, query2)); + + when(lookupService.getKvStateLookupInfo(eq(jobId), eq(query2))) + .thenReturn(unknownKeyGroupLocation); + + result = client.getKvState(jobId, query2, 0, new byte[0]); + + try { + Await.result(result, timeout); + fail("Did not throw expected UnknownKvStateKeyGroupLocation exception"); + } catch (UnknownKvStateKeyGroupLocation ignored) { + // Expected + } + + verify(lookupService, times(2)).getKvStateLookupInfo(eq(jobId), eq(query2)); + + // + // UnknownKvStateID + // + String query3 = "water"; + KvStateID kvStateId = new KvStateID(); + Future unknownKvStateId = Futures.failed(new UnknownKvStateID(kvStateId)); + + KvStateServerAddress serverAddress = new KvStateServerAddress(InetAddress.getLocalHost(), 12323); + KvStateLocation location = new KvStateLocation(jobId, new JobVertexID(), numKeyGroups, query3); + for (int i = 0; i < numKeyGroups; i++) { + location.registerKvState(i, kvStateId, serverAddress); + } + + when(lookupService.getKvStateLookupInfo(eq(jobId), eq(query3))) + .thenReturn(Futures.successful(location)); + + when(networkClient.getKvState(eq(serverAddress), eq(kvStateId), any(byte[].class))) + .thenReturn(unknownKvStateId); + + result = client.getKvState(jobId, query3, 0, new byte[0]); + + try { + Await.result(result, timeout); + fail("Did not throw expected UnknownKvStateID exception"); + } catch (UnknownKvStateID ignored) { + // Expected + } + + verify(lookupService, times(2)).getKvStateLookupInfo(eq(jobId), eq(query3)); + + // + // ConnectException + // + String query4 = "space"; + Future connectException = Futures.failed(new ConnectException()); + kvStateId = new KvStateID(); + + serverAddress = new KvStateServerAddress(InetAddress.getLocalHost(), 11123); + location = new KvStateLocation(jobId, new JobVertexID(), numKeyGroups, query4); + for (int i = 0; i < numKeyGroups; i++) { + location.registerKvState(i, kvStateId, serverAddress); + } + + when(lookupService.getKvStateLookupInfo(eq(jobId), eq(query4))) + .thenReturn(Futures.successful(location)); + + when(networkClient.getKvState(eq(serverAddress), eq(kvStateId), any(byte[].class))) + .thenReturn(connectException); + + result = client.getKvState(jobId, query4, 0, new byte[0]); + + try { + Await.result(result, timeout); + fail("Did not throw expected ConnectException exception"); + } catch (ConnectException ignored) { 
+ // Expected + } + + verify(lookupService, times(2)).getKvStateLookupInfo(eq(jobId), eq(query4)); + + // + // Other Exceptions don't lead to a retry no retry + // + String query5 = "universe"; + Future exception = Futures.failed(new RuntimeException("Test exception")); + when(lookupService.getKvStateLookupInfo(eq(jobId), eq(query5))) + .thenReturn(exception); + + client.getKvState(jobId, query5, 0, new byte[0]); + + verify(lookupService, times(1)).getKvStateLookupInfo(eq(jobId), eq(query5)); + } finally { + client.shutDown(); + } + } + + /** + * Tests queries against multiple servers. + * + *

The servers are populated with different keys and the client queries + * all available keys from all servers. + */ + @Test + public void testIntegrationWithKvStateServer() throws Exception { + // Config + int numServers = 2; + int numKeys = 1024; + + JobID jobId = new JobID(); + JobVertexID jobVertexId = new JobVertexID(); + + KvStateServer[] servers = new KvStateServer[numServers]; + AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers]; + + QueryableStateClient client = null; + KvStateClient networkClient = null; + AtomicKvStateRequestStats networkClientStats = new AtomicKvStateRequestStats(); + + try { + KvStateRegistry[] registries = new KvStateRegistry[numServers]; + KvStateID[] kvStateIds = new KvStateID[numServers]; + List> kvStates = new ArrayList<>(); + + // Start the servers + for (int i = 0; i < numServers; i++) { + registries[i] = new KvStateRegistry(); + serverStats[i] = new AtomicKvStateRequestStats(); + servers[i] = new KvStateServer(InetAddress.getLocalHost(), 0, 1, 1, registries[i], serverStats[i]); + servers[i].start(); + + // Register state + MemValueState kvState = new MemValueState<>( + IntSerializer.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null)); + + kvStates.add(kvState); + + kvStateIds[i] = registries[i].registerKvState( + jobId, + new JobVertexID(), + i, // key group index + "choco", + kvState); + } + + int[] expectedRequests = new int[numServers]; + + for (int key = 0; key < numKeys; key++) { + int targetKeyGroupIndex = MathUtils.murmurHash(key) % numServers; + expectedRequests[targetKeyGroupIndex]++; + + MemValueState kvState = kvStates.get(targetKeyGroupIndex); + + kvState.setCurrentKey(key); + kvState.setCurrentNamespace(VoidNamespace.INSTANCE); + kvState.update(1337 + key); + } + + // Location lookup service + KvStateLocation location = new KvStateLocation(jobId, jobVertexId, numServers, "choco"); + for (int keyGroupIndex = 0; keyGroupIndex < numServers; keyGroupIndex++) { + location.registerKvState(keyGroupIndex, kvStateIds[keyGroupIndex], servers[keyGroupIndex].getAddress()); + } + + KvStateLocationLookupService lookupService = mock(KvStateLocationLookupService.class); + when(lookupService.getKvStateLookupInfo(eq(jobId), eq("choco"))) + .thenReturn(Futures.successful(location)); + + // The client + networkClient = new KvStateClient(1, networkClientStats); + + client = new QueryableStateClient(lookupService, networkClient, testActorSystem.dispatcher()); + + // Send all queries + List> futures = new ArrayList<>(numKeys); + for (int key = 0; key < numKeys; key++) { + byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + futures.add(client.getKvState(jobId, "choco", key, serializedKeyAndNamespace)); + } + + // Verify results + Future> future = Futures.sequence(futures, testActorSystem.dispatcher()); + Iterable results = Await.result(future, timeout); + + int index = 0; + for (byte[] buffer : results) { + int deserializedValue = KvStateRequestSerializer.deserializeValue(buffer, IntSerializer.INSTANCE); + assertEquals(1337 + index, deserializedValue); + index++; + } + + // Verify requests + for (int i = 0; i < numServers; i++) { + int numRetries = 10; + for (int retry = 0; retry < numRetries; retry++) { + try { + assertEquals("Unexpected number of requests", expectedRequests[i], serverStats[i].getNumRequests()); + assertEquals("Unexpected 
success requests", expectedRequests[i], serverStats[i].getNumSuccessful()); + assertEquals("Unexpected failed requests", 0, serverStats[i].getNumFailed()); + break; + } catch (Throwable t) { + // Retry + if (retry == numRetries-1) { + throw t; + } else { + Thread.sleep(100); + } + } + } + } + } finally { + if (client != null) { + client.shutDown(); + } + + if (networkClient != null) { + networkClient.shutDown(); + } + + for (KvStateServer server : servers) { + if (server != null) { + server.shutDown(); + } + } + } + } + + /** + * Tests that the QueryableState client correctly caches location lookups + * keyed by both job and name. This test is mainly due to a previous bug due + * to which cache entries were by name only. This is a problem, because the + * same client can be used to query multiple jobs. + */ + @Test + public void testLookupMultipleJobIds() throws Exception { + String name = "unique-per-job"; + + // Exact contents don't matter here + KvStateLocation location = new KvStateLocation(new JobID(), new JobVertexID(), 1, name); + location.registerKvState(0, new KvStateID(), new KvStateServerAddress(InetAddress.getLocalHost(), 892)); + + JobID jobId1 = new JobID(); + JobID jobId2 = new JobID(); + + KvStateLocationLookupService lookupService = mock(KvStateLocationLookupService.class); + + when(lookupService.getKvStateLookupInfo(any(JobID.class), anyString())) + .thenReturn(Futures.successful(location)); + + KvStateClient networkClient = mock(KvStateClient.class); + when(networkClient.getKvState(any(KvStateServerAddress.class), any(KvStateID.class), any(byte[].class))) + .thenReturn(Futures.successful(new byte[0])); + + QueryableStateClient client = new QueryableStateClient( + lookupService, + networkClient, + testActorSystem.dispatcher()); + + // Query ies with same name, but different job IDs should lead to a + // single lookup per query and job ID. + client.getKvState(jobId1, name, 0, new byte[0]); + client.getKvState(jobId2, name, 0, new byte[0]); + + verify(lookupService, times(1)).getKvStateLookupInfo(eq(jobId1), eq(name)); + verify(lookupService, times(1)).getKvStateLookupInfo(eq(jobId2), eq(name)); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java index ca7157a33b482..147a3e0accef5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerComponentsStartupShutdownTest.java @@ -98,8 +98,8 @@ public void testComponentsStartupShutdown() { config); final NetworkEnvironmentConfiguration netConf = new NetworkEnvironmentConfiguration( - 32, BUFFER_SIZE, MemoryType.HEAP, IOManager.IOMode.SYNC, Option.empty(), - new Tuple2(0, 0)); + 32, BUFFER_SIZE, MemoryType.HEAP, IOManager.IOMode.SYNC, 0, 0, 0, + Option.empty(), new Tuple2(0, 0)); final InstanceConnectionInfo connectionInfo = new InstanceConnectionInfo(InetAddress.getLocalHost(), 10000); From 69935d1f9658ecce1cafcdab37ef27e933dcb02a Mon Sep 17 00:00:00 2001 From: Ufuk Celebi Date: Mon, 30 May 2016 14:08:34 +0200 Subject: [PATCH 6/6] [FLINK-3779] [streaming-java, streaming-scala] Add QueryableStateStream to KeyedStream [runtime, test-utils, tests] - Exposes queryable state on the API via KeyedStream#asQueryableState(String, StateDescriptor). 
This creates and operator, which consumes the keyed stream and exposes the stream as queryable state. --- .../minicluster/FlinkMiniCluster.scala | 2 +- .../minicluster/LocalFlinkMiniCluster.scala | 2 +- .../streaming/api/datastream/KeyedStream.java | 121 ++ .../api/datastream/QueryableStateStream.java | 87 ++ .../query/AbstractQueryableStateOperator.java | 84 ++ .../QueryableAppendingStateOperator.java | 45 + .../query/QueryableValueStateOperator.java | 45 + .../streaming/api/scala/KeyedStream.scala | 118 +- .../test/util/ForkableFlinkMiniCluster.scala | 25 + .../test/query/QueryableStateITCase.java | 1237 +++++++++++++++++ 10 files changed, 1762 insertions(+), 4 deletions(-) create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/QueryableStateStream.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/AbstractQueryableStateOperator.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableAppendingStateOperator.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableValueStateOperator.java create mode 100644 flink-tests/src/test/java/org/apache/flink/test/query/QueryableStateITCase.java diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/FlinkMiniCluster.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/FlinkMiniCluster.scala index 5074b8c1fbac1..f6e9360350d92 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/FlinkMiniCluster.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/FlinkMiniCluster.scala @@ -92,7 +92,7 @@ abstract class FlinkMiniCluster( val numJobManagers = getNumberOfJobManagers - val numTaskManagers = configuration.getInteger( + var numTaskManagers = configuration.getInteger( ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, ConfigConstants.DEFAULT_LOCAL_NUMBER_TASK_MANAGER) diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala index 5bebd48de2a29..d30c0470deb42 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/minicluster/LocalFlinkMiniCluster.scala @@ -30,7 +30,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceID import org.apache.flink.runtime.io.network.netty.NettyConfig import org.apache.flink.runtime.jobmanager.{MemoryArchivist, JobManager} import org.apache.flink.runtime.messages.JobManagerMessages -import org.apache.flink.runtime.messages.JobManagerMessages.{StoppingFailure, StoppingResponse, RunningJobsStatus, RunningJobs} +import org.apache.flink.runtime.messages.JobManagerMessages.{CancellationFailure, CancellationResponse, StoppingFailure, StoppingResponse, RunningJobsStatus, RunningJobs} import org.apache.flink.runtime.taskmanager.TaskManager import org.apache.flink.runtime.util.EnvironmentInformation diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java index cbf115b20b236..6998890309768 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java @@ -22,6 +22,10 @@ import org.apache.flink.annotation.Public; import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.Utils; import org.apache.flink.api.java.functions.KeySelector; @@ -30,6 +34,8 @@ import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction; import org.apache.flink.streaming.api.functions.aggregation.ComparableAggregator; import org.apache.flink.streaming.api.functions.aggregation.SumAggregator; +import org.apache.flink.streaming.api.functions.query.QueryableAppendingStateOperator; +import org.apache.flink.streaming.api.functions.query.QueryableValueStateOperator; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamGroupedFold; @@ -52,6 +58,8 @@ import org.apache.flink.streaming.runtime.partitioner.HashPartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import java.util.UUID; + /** * A {@code KeyedStream} represents a {@link DataStream} on which operator state is * partitioned by key using a provided {@link KeySelector}. Typical operations supported by a @@ -515,4 +523,117 @@ protected SingleOutputStreamOperator aggregate(AggregationFunction aggrega clean(aggregate), getType().createSerializer(getExecutionConfig())); return transform("Keyed Aggregation", getType(), operator); } + + /** + * Publishes the keyed stream as queryable ValueState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @return Queryable state instance + */ + @PublicEvolving + public QueryableStateStream asQueryableState(String queryableStateName) { + ValueStateDescriptor valueStateDescriptor = new ValueStateDescriptor( + UUID.randomUUID().toString(), + getType(), + null); + + return asQueryableState(queryableStateName, valueStateDescriptor); + } + + /** + * Publishes the keyed stream as a queryable ValueState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + public QueryableStateStream asQueryableState( + String queryableStateName, + ValueStateDescriptor stateDescriptor) { + + transform("Queryable state: " + queryableStateName, + getType(), + new QueryableValueStateOperator<>(queryableStateName, stateDescriptor)); + + stateDescriptor.initializeSerializerUnlessSet(getExecutionConfig()); + + return new QueryableStateStream<>( + queryableStateName, + stateDescriptor.getSerializer(), + getKeyType().createSerializer(getExecutionConfig())); + } + + /** + * Publishes the keyed stream as a queryable ListStance instance. 
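+ * <p>For illustration, assuming a {@code KeyedStream<Event, String> keyedStream}
+ * keyed by some String field (a sketch; {@code Event} and the names are made up):
+ *
+ * <pre>{@code
+ * QueryableStateStream<String, Event> queryable = keyedStream.asQueryableState(
+ *     "events",
+ *     new ListStateDescriptor<>("events", keyedStream.getType()));
+ * }</pre>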
+ * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + public QueryableStateStream asQueryableState( + String queryableStateName, + ListStateDescriptor stateDescriptor) { + + transform("Queryable state: " + queryableStateName, + getType(), + new QueryableAppendingStateOperator<>(queryableStateName, stateDescriptor)); + + stateDescriptor.initializeSerializerUnlessSet(getExecutionConfig()); + + return new QueryableStateStream<>( + queryableStateName, + getType().createSerializer(getExecutionConfig()), + getKeyType().createSerializer(getExecutionConfig())); + } + + /** + * Publishes the keyed stream as a queryable FoldingState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + public QueryableStateStream asQueryableState( + String queryableStateName, + FoldingStateDescriptor stateDescriptor) { + + transform("Queryable state: " + queryableStateName, + getType(), + new QueryableAppendingStateOperator<>(queryableStateName, stateDescriptor)); + + stateDescriptor.initializeSerializerUnlessSet(getExecutionConfig()); + + return new QueryableStateStream<>( + queryableStateName, + stateDescriptor.getSerializer(), + getKeyType().createSerializer(getExecutionConfig())); + } + + /** + * Publishes the keyed stream as a queryable ReducingState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + public QueryableStateStream asQueryableState( + String queryableStateName, + ReducingStateDescriptor stateDescriptor) { + + transform("Queryable state: " + queryableStateName, + getType(), + new QueryableAppendingStateOperator<>(queryableStateName, stateDescriptor)); + + stateDescriptor.initializeSerializerUnlessSet(getExecutionConfig()); + + return new QueryableStateStream<>( + queryableStateName, + stateDescriptor.getSerializer(), + getKeyType().createSerializer(getExecutionConfig())); + } + } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/QueryableStateStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/QueryableStateStream.java new file mode 100644 index 0000000000000..d0de2ab7f6bc0 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/QueryableStateStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.datastream; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +/** + * Queryable state stream instance. + * + * @param State key type + * @param State value type + */ +@PublicEvolving +public class QueryableStateStream { + + /** Name under which the state is queryable. */ + private final String queryableStateName; + + /** Key serializer for the state instance. */ + private final TypeSerializer keySerializer; + + /** Value serializer for the state instance. */ + private final TypeSerializer valueSerializer; + + /** + * Creates a queryable state stream. + * + * @param queryableStateName Name under which to publish the queryable state instance + * @param valueSerializer Value serializer for the state instance + * @param keySerializer Key serializer for the state instance + */ + public QueryableStateStream( + String queryableStateName, + TypeSerializer valueSerializer, + TypeSerializer keySerializer) { + + this.queryableStateName = Preconditions.checkNotNull(queryableStateName, "Queryable state name"); + this.valueSerializer = Preconditions.checkNotNull(valueSerializer, "Value serializer"); + this.keySerializer = Preconditions.checkNotNull(keySerializer, "Key serializer"); + } + + /** + * Returns the name under which the state can be queried. + * + * @return Name under which state can be queried. + */ + public String getQueryableStateName() { + return queryableStateName; + } + + /** + * Returns the value serializer for the queryable state instance. + * + * @return Value serializer for the state instance + */ + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + /** + * Returns the key serializer for the queryable state instance. + * + * @return Key serializer for the state instance. + */ + public TypeSerializer getKeySerializer() { + return keySerializer; + } + +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/AbstractQueryableStateOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/AbstractQueryableStateOperator.java new file mode 100644 index 0000000000000..09c9b01179696 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/AbstractQueryableStateOperator.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.functions.query; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.util.Preconditions; + +/** + * Internal operator handling queryable state instances (setup and update). + * + * @param State type + * @param Input type + */ +@Internal +abstract class AbstractQueryableStateOperator + extends AbstractStreamOperator + implements OneInputStreamOperator { + + /** State descriptor for the queryable state instance. */ + protected final StateDescriptor stateDescriptor; + + /** + * Name under which the queryable state is registered. + */ + protected final String registrationName; + + /** + * The state instance created on open. This is updated by the subclasses + * of this class, because the state update interface depends on the state + * type (e.g. AppendingState#add(IN) vs. ValueState#update(OUT)). + */ + protected transient S state; + + public AbstractQueryableStateOperator( + String registrationName, + StateDescriptor stateDescriptor) { + + this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name"); + this.stateDescriptor = Preconditions.checkNotNull(stateDescriptor, "State descriptor"); + + if (stateDescriptor.isQueryable()) { + String name = stateDescriptor.getQueryableStateName(); + if (!name.equals(registrationName)) { + throw new IllegalArgumentException("StateDescriptor already marked as " + + "queryable with name '" + name + "', but created operator with name '" + + registrationName + "'."); + } // else: all good, already registered with same name + } else { + stateDescriptor.setQueryable(registrationName); + } + } + + @Override + public void open() throws Exception { + super.open(); + state = getPartitionedState(stateDescriptor); + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + // Nothing to do + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableAppendingStateOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableAppendingStateOperator.java new file mode 100644 index 0000000000000..7ac14ed1d3b42 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableAppendingStateOperator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.functions.query; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.AppendingState; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * Internal operator handling queryable AppendingState instances. + * + * @param Input type + */ +@Internal +public class QueryableAppendingStateOperator extends AbstractQueryableStateOperator, IN> { + + public QueryableAppendingStateOperator( + String registrationName, + StateDescriptor, ?> stateDescriptor) { + + super(registrationName, stateDescriptor); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + state.add(element.getValue()); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableValueStateOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableValueStateOperator.java new file mode 100644 index 0000000000000..49605a923db31 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/query/QueryableValueStateOperator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions.query; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * Internal operator handling queryable ValueState instances. 
+ * + * @param Input type + */ +@Internal +public class QueryableValueStateOperator extends AbstractQueryableStateOperator, IN> { + + public QueryableValueStateOperator( + String registrationName, + StateDescriptor, IN> stateDescriptor) { + + super(registrationName, stateDescriptor); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + state.update(element.getValue()); + } +} diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala index 359b5b1de6f95..68eebeaa12275 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala +++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/KeyedStream.scala @@ -18,13 +18,15 @@ package org.apache.flink.streaming.api.scala -import org.apache.flink.annotation.{PublicEvolving, Internal, Public} +import org.apache.flink.annotation.{Internal, Public, PublicEvolving} import org.apache.flink.api.common.functions._ +import org.apache.flink.api.common.state.{FoldingStateDescriptor, ListStateDescriptor, ReducingStateDescriptor, ValueStateDescriptor} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeutils.TypeSerializer -import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, KeyedStream => KeyedJavaStream, WindowedStream => WindowedJavaStream} +import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, KeyedStream => KeyedJavaStream, QueryableStateStream, WindowedStream => WindowedJavaStream} import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType import org.apache.flink.streaming.api.functions.aggregation.{ComparableAggregator, SumAggregator} +import org.apache.flink.streaming.api.functions.query.{QueryableAppendingStateOperator, QueryableValueStateOperator} import org.apache.flink.streaming.api.operators.StreamGroupedReduce import org.apache.flink.streaming.api.scala.function.StatefulFunction import org.apache.flink.streaming.api.windowing.assigners._ @@ -369,5 +371,117 @@ class KeyedStream[T, K](javaStream: KeyedJavaStream[T, K]) extends DataStream[T] flatMap(flatMapper) } + + /** + * Publishes the keyed stream as a queryable ValueState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @return Queryable state instance + */ + @PublicEvolving + def asQueryableState(queryableStateName: String) : QueryableStateStream[K, T] = { + val stateDescriptor = new ValueStateDescriptor( + queryableStateName, + dataType.createSerializer(executionConfig), + null.asInstanceOf[T]) + + asQueryableState(queryableStateName, stateDescriptor) + } + + /** + * Publishes the keyed stream as a queryable ValueState instance. 
+ * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + def asQueryableState( + queryableStateName: String, + stateDescriptor: ValueStateDescriptor[T]) : QueryableStateStream[K, T] = { + + transform( + s"Queryable state: $queryableStateName", + new QueryableValueStateOperator(queryableStateName, stateDescriptor))(dataType) + + stateDescriptor.initializeSerializerUnlessSet(executionConfig) + + new QueryableStateStream( + queryableStateName, + stateDescriptor.getSerializer, + getKeyType.createSerializer(executionConfig)) + } + + /** + * Publishes the keyed stream as a queryable ListState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + def asQueryableState( + queryableStateName: String, + stateDescriptor: ListStateDescriptor[T]) : QueryableStateStream[K, T] = { + + transform( + s"Queryable state: $queryableStateName", + new QueryableAppendingStateOperator(queryableStateName, stateDescriptor))(dataType) + + stateDescriptor.initializeSerializerUnlessSet(executionConfig) + + new QueryableStateStream( + queryableStateName, + stateDescriptor.getSerializer, + getKeyType.createSerializer(executionConfig)) + } + + /** + * Publishes the keyed stream as a queryable FoldingState instance. + * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + def asQueryableState[ACC]( + queryableStateName: String, + stateDescriptor: FoldingStateDescriptor[T, ACC]) : QueryableStateStream[K, ACC] = { + + transform( + s"Queryable state: $queryableStateName", + new QueryableAppendingStateOperator(queryableStateName, stateDescriptor))(dataType) + + stateDescriptor.initializeSerializerUnlessSet(executionConfig) + + new QueryableStateStream( + queryableStateName, + stateDescriptor.getSerializer, + getKeyType.createSerializer(executionConfig)) + } + + /** + * Publishes the keyed stream as a queryable ReducingState instance. 
+ * + * @param queryableStateName Name under which to the publish the queryable state instance + * @param stateDescriptor State descriptor to create state instance from + * @return Queryable state instance + */ + @PublicEvolving + def asQueryableState( + queryableStateName: String, + stateDescriptor: ReducingStateDescriptor[T]) : QueryableStateStream[K, T] = { + + transform( + s"Queryable state: $queryableStateName", + new QueryableAppendingStateOperator(queryableStateName, stateDescriptor))(dataType) + + stateDescriptor.initializeSerializerUnlessSet(executionConfig) + + new QueryableStateStream( + queryableStateName, + stateDescriptor.getSerializer, + getKeyType.createSerializer(executionConfig)) + } } diff --git a/flink-test-utils-parent/flink-test-utils/src/main/scala/org/apache/flink/test/util/ForkableFlinkMiniCluster.scala b/flink-test-utils-parent/flink-test-utils/src/main/scala/org/apache/flink/test/util/ForkableFlinkMiniCluster.scala index 79c5a2590caae..42c0a6a88f147 100644 --- a/flink-test-utils-parent/flink-test-utils/src/main/scala/org/apache/flink/test/util/ForkableFlinkMiniCluster.scala +++ b/flink-test-utils-parent/flink-test-utils/src/main/scala/org/apache/flink/test/util/ForkableFlinkMiniCluster.scala @@ -165,6 +165,31 @@ class ForkableFlinkMiniCluster( classOf[TestingTaskManager]) } + def addTaskManager(): Unit = { + if (useSingleActorSystem) { + (jobManagerActorSystems, taskManagerActors) match { + case (Some(jmSystems), Some(tmActors)) => + val index = numTaskManagers + taskManagerActors = Some(tmActors :+ startTaskManager(index, jmSystems(0))) + numTaskManagers += 1 + case _ => throw new IllegalStateException("Cluster has not been started properly.") + } + } else { + (taskManagerActorSystems, taskManagerActors) match { + case (Some(tmSystems), Some(tmActors)) => + val index = numTaskManagers + val newTmSystem = startTaskManagerActorSystem(index) + val newTmActor = startTaskManager(index, newTmSystem) + + taskManagerActorSystems = Some(tmSystems :+ newTmSystem) + taskManagerActors = Some(tmActors :+ newTmActor) + + numTaskManagers += 1 + case _ => throw new IllegalStateException("Cluster has not been started properly.") + } + } + } + def restartLeadingJobManager(): Unit = { this.synchronized { (jobManagerActorSystems, jobManagerActors) match { diff --git a/flink-tests/src/test/java/org/apache/flink/test/query/QueryableStateITCase.java b/flink-tests/src/test/java/org/apache/flink/test/query/QueryableStateITCase.java new file mode 100644 index 0000000000000..c31f4e5fad6d6 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/query/QueryableStateITCase.java @@ -0,0 +1,1237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.test.query; + +import akka.actor.ActorRef; +import akka.actor.ActorSystem; +import akka.actor.PoisonPill; +import akka.dispatch.Futures; +import akka.dispatch.Mapper; +import akka.dispatch.OnSuccess; +import akka.dispatch.Recover; +import akka.pattern.Patterns; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.functions.FoldFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.state.FoldingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.akka.AkkaUtils; +import org.apache.flink.runtime.execution.SuppressRestartsException; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobStatus; +import org.apache.flink.runtime.messages.JobManagerMessages; +import org.apache.flink.runtime.messages.JobManagerMessages.CancellationSuccess; +import org.apache.flink.runtime.messages.JobManagerMessages.JobFound; +import org.apache.flink.runtime.query.KvStateLocation; +import org.apache.flink.runtime.query.KvStateMessage; +import org.apache.flink.runtime.query.QueryableStateClient; +import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages; +import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages; +import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.ResponseRunningTasks; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.QueryableStateStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.test.util.ForkableFlinkMiniCluster; +import org.apache.flink.util.MathUtils; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.TestLogger; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import scala.reflect.ClassTag$; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicLongArray; + +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.ExecutionGraphFound; +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.JobStatusIs; 
+import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobStatus; +import static org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.RequestExecutionGraph; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class QueryableStateITCase extends TestLogger { + + private final static FiniteDuration TEST_TIMEOUT = new FiniteDuration(100, TimeUnit.SECONDS); + + private final static ActorSystem TEST_ACTOR_SYSTEM = AkkaUtils.createDefaultActorSystem(); + + private final static int NUM_TMS = 2; + private final static int NUM_SLOTS_PER_TM = 4; + private final static int NUM_SLOTS = NUM_TMS * NUM_SLOTS_PER_TM; + + /** + * Shared between all the test. Make sure to have at least NUM_SLOTS + * available after your test finishes, e.g. cancel the job you submitted. + */ + private static ForkableFlinkMiniCluster cluster; + + @BeforeClass + public static void setup() { + try { + Configuration config = new Configuration(); + config.setInteger(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, 4); + config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, NUM_TMS); + config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, NUM_SLOTS_PER_TM); + config.setInteger(ConfigConstants.QUERYABLE_STATE_CLIENT_NETWORK_THREADS, 1); + config.setInteger(ConfigConstants.QUERYABLE_STATE_SERVER_NETWORK_THREADS, 1); + + cluster = new ForkableFlinkMiniCluster(config, false); + cluster.start(true); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @AfterClass + public static void tearDown() { + try { + cluster.shutdown(); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + if (TEST_ACTOR_SYSTEM != null) { + TEST_ACTOR_SYSTEM.shutdown(); + } + } + + /** + * Runs a simple topology producing random (key, 1) pairs at the sources (where + * number of keys is in fixed in range 0...numKeys). The records are keyed and + * a reducing queryable state instance is created, which sums up the records. + * + * After submitting the job in detached mode, the QueryableStateCLient is used + * to query the counts of each key in rounds until all keys have non-zero counts. + */ + @Test + @SuppressWarnings("unchecked") + public void testQueryableState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + final int numKeys = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + + try { + // + // Test program + // + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
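+			// Topology: TestKeyRangeSource emits random (key, 1) records, which are keyed
+			// by f0 and summed into a ReducingState published under "hakuna-matata". The
+			// polling loop further down re-queries every key until its count is > 0,
+			// retrying transient lookup failures via getKvStateWithRetries().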
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestKeyRangeSource(numKeys)); + + // Reducing state + ReducingStateDescriptor> reducingState = new ReducingStateDescriptor<>( + "any-name", + new SumReduce(), + source.getType()); + + final String queryName = "hakuna-matata"; + + final QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName, reducingState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + cluster.submitJobDetached(jobGraph); + + // + // Start querying + // + jobId = jobGraph.getJobID(); + + final AtomicLongArray counts = new AtomicLongArray(numKeys); + + final FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + + boolean allNonZero = false; + while (!allNonZero && deadline.hasTimeLeft()) { + allNonZero = true; + + final List> futures = new ArrayList<>(numKeys); + + for (int i = 0; i < numKeys; i++) { + final int key = i; + + if (counts.get(key) > 0) { + // Skip this one + continue; + } else { + allNonZero = false; + } + + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + Future serializedResult = getKvStateWithRetries( + client, + jobId, + queryName, + key, + serializedKey, + retryDelay); + + serializedResult.onSuccess(new OnSuccess() { + @Override + public void onSuccess(byte[] result) throws Throwable { + Tuple2 value = KvStateRequestSerializer.deserializeValue( + result, + queryableState.getValueSerializer()); + + counts.set(key, value.f1); + + assertEquals("Key mismatch", key, value.f0.intValue()); + } + }, TEST_ACTOR_SYSTEM.dispatcher()); + + futures.add(serializedResult); + } + + Future> futureSequence = Futures.sequence( + futures, + TEST_ACTOR_SYSTEM.dispatcher()); + + Await.ready(futureSequence, deadline.timeLeft()); + } + + assertTrue("Not all keys are non-zero", allNonZero); + + // All should be non-zero + for (int i = 0; i < numKeys; i++) { + long count = counts.get(i); + assertTrue("Count at position " + i + " is " + count, count > 0); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Queries a random key and waits for some checkpoints to complete. After + * that the task manager where this key was held is killed. Then query the + * key again and check for the expected Exception. Finally, add another + * task manager and re-query the key (expecting a count >= the previous + * one). 
+ */ + @Test + public void testQueryableStateWithTaskManagerFailure() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + final int numKeys = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + + try { + // + // Test program + // + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. + env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + env.getCheckpointConfig().setCheckpointInterval(1000); + + DataStream> source = env + .addSource(new TestKeyRangeSource(numKeys)); + + // Reducing state + ReducingStateDescriptor> reducingState = new ReducingStateDescriptor<>( + "any-name", + new SumReduce(), + source.getType()); + + final String queryName = "hakuna-matata"; + + final QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName, reducingState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + cluster.submitJobDetached(jobGraph); + + // + // Start querying + // + jobId = jobGraph.getJobID(); + + final int key = ThreadLocalRandom.current().nextInt(numKeys); + + // Query a random key + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + long countForKey = 0; + + boolean success = false; + while (!success && deadline.hasTimeLeft()) { + final FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + Future serializedResultFuture = getKvStateWithRetries( + client, + jobId, + queryName, + key, + serializedKey, + retryDelay); + + byte[] serializedResult = Await.result(serializedResultFuture, deadline.timeLeft()); + + Tuple2 result = KvStateRequestSerializer.deserializeValue( + serializedResult, + queryableState.getValueSerializer()); + + countForKey = result.f1; + + assertEquals("Key mismatch", key, result.f0.intValue()); + success = countForKey > 1000; // Wait for some progress + } + + assertTrue("No progress for count", countForKey > 1000); + + long currentCheckpointId = TestKeyRangeSource.LATEST_CHECKPOINT_ID.get(); + long waitUntilCheckpointId = currentCheckpointId + 5; + + // Wait for some checkpoint after the query result + while (deadline.hasTimeLeft() && + TestKeyRangeSource.LATEST_CHECKPOINT_ID.get() < waitUntilCheckpointId) { + Thread.sleep(500); + } + + assertTrue("Did not complete enough checkpoints to continue", + TestKeyRangeSource.LATEST_CHECKPOINT_ID.get() >= waitUntilCheckpointId); + + // + // Find out on which TaskManager the KvState is located and kill that TaskManager + // + // This is the subtask index + int keyGroupIndex = MathUtils.murmurHash(key) % NUM_SLOTS; + + // Find out which task manager holds this key + Future egFuture = cluster.getLeaderGateway(deadline.timeLeft()) + .ask(new RequestExecutionGraph(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(ExecutionGraphFound.class)) + .map(new Mapper() { + @Override + public ExecutionGraph apply(ExecutionGraphFound found) { + return found.executionGraph(); + } + }, TEST_ACTOR_SYSTEM.dispatcher()); + ExecutionGraph eg = 
Await.result(egFuture, deadline.timeLeft()); + + Future locationFuture = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new KvStateMessage.LookupKvStateLocation(jobId, queryName), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(KvStateLocation.class)); + + KvStateLocation location = Await.result(locationFuture, deadline.timeLeft()); + + ExecutionAttemptID executionId = eg.getJobVertex(location.getJobVertexId()) + .getTaskVertices()[keyGroupIndex] + .getCurrentExecutionAttempt() + .getAttemptId(); + + List taskManagers = cluster.getTaskManagersAsJava(); + ActorRef taskManagerToKill = null; + for (ActorRef taskManager : taskManagers) { + Future runningFuture = Patterns.ask( + taskManager, + TestingTaskManagerMessages.getRequestRunningTasksMessage(), + deadline.timeLeft().toMillis()) + .mapTo(ClassTag$.MODULE$.apply(ResponseRunningTasks.class)); + + ResponseRunningTasks running = Await.result(runningFuture, deadline.timeLeft()); + + if (running.asJava().containsKey(executionId)) { + taskManagerToKill = taskManager; + break; + } + } + + assertNotNull("Did not find TaskManager holding the key", taskManagerToKill); + + // Kill the task manager + taskManagerToKill.tell(PoisonPill.getInstance(), ActorRef.noSender()); + + success = false; + for (int i = 0; i < 10 && !success; i++) { + try { + // Wait for the expected error. We might have to retry if + // the query is very fast. + Await.result(client.getKvState(jobId, queryName, key, serializedKey), deadline.timeLeft()); + Thread.sleep(500); + } catch (Throwable ignored) { + success = true; + } + } + + assertTrue("Query did not fail", success); + + // Now start another task manager + cluster.addTaskManager(); + + final FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + Future serializedResultFuture = getKvStateWithRetries( + client, + jobId, + queryName, + key, + serializedKey, + retryDelay); + + byte[] serializedResult = Await.result(serializedResultFuture, deadline.timeLeft()); + + Tuple2 result = KvStateRequestSerializer.deserializeValue( + serializedResult, + queryableState.getValueSerializer()); + + assertTrue("Count moved backwards", result.f1 >= countForKey); + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests that duplicate query registrations fail the job at the JobManager. + */ + @Test + public void testDuplicateRegistrationFailsJob() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + final int numKeys = 1024; + + JobID jobId = null; + + try { + // + // Test program + // + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
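+			// The same queryable state name is registered twice below. The job is
+			// expected to go to FAILED with a SuppressRestartsException whose cause is
+			// an IllegalStateException mentioning the duplicated name ("duplicate-me").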
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestKeyRangeSource(numKeys)); + + // Reducing state + ReducingStateDescriptor> reducingState = new ReducingStateDescriptor<>( + "any-name", + new SumReduce(), + source.getType()); + + final String queryName = "duplicate-me"; + + final QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName, reducingState); + + final QueryableStateStream> duplicate = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState(queryName); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + Future failedFuture = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new NotifyWhenJobStatus(jobId, JobStatus.FAILED), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(JobStatusIs.class)); + + cluster.submitJobDetached(jobGraph); + + JobStatusIs jobStatus = Await.result(failedFuture, deadline.timeLeft()); + assertEquals(JobStatus.FAILED, jobStatus.state()); + + // Get the job and check the cause + JobFound jobFound = Await.result( + cluster.getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.RequestJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(JobFound.class)), + deadline.timeLeft()); + + Throwable failureCause = jobFound.executionGraph().getFailureCause(); + + assertTrue("Not instance of SuppressRestartsException", failureCause instanceof SuppressRestartsException); + assertTrue("Not caused by IllegalStateException", failureCause.getCause() instanceof IllegalStateException); + Throwable duplicateException = failureCause.getCause(); + assertTrue("Exception does not contain registration name", duplicateException.getMessage().contains(queryName)); + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + } + } + + /** + * Tests simple value state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The tests succeeds after each subtask index is queried with + * value numElements (the latest element updated the state). + */ + @Test + public void testValueState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
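+			// Each parallel source writes (subtaskIndex, 0)..(subtaskIndex, numElements)
+			// into a ValueState, so the query loop below is done with a key as soon as
+			// the returned tuple carries f1 == numElements (the latest update wins).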
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // Value state + ValueStateDescriptor> valueState = new ValueStateDescriptor<>( + "any", + source.getType(), + null); + + QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState("hakuna", valueState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Now query + long expected = numElements; + + FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + for (int key = 0; key < NUM_SLOTS; key++) { + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + boolean success = false; + while (deadline.hasTimeLeft() && !success) { + Future future = getKvStateWithRetries(client, + jobId, + queryableState.getQueryableStateName(), + key, + serializedKey, + retryDelay); + + byte[] serializedValue = Await.result(future, deadline.timeLeft()); + + Tuple2 value = KvStateRequestSerializer.deserializeValue( + serializedValue, + queryableState.getValueSerializer()); + + assertEquals("Key mismatch", key, value.f0.intValue()); + if (expected == value.f1) { + success = true; + } else { + // Retry + Thread.sleep(50); + } + } + + assertTrue("Did not succeed query", success); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests simple value state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The tests succeeds after each subtask index is queried with + * value numElements (the latest element updated the state). + * + * This is the same as the simple value state test, but uses the API shortcut. + */ + @Test + public void testValueStateShortcut() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
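+			// Same assertion as testValueState(), but the state is published through the
+			// descriptor-less asQueryableState(name) shortcut; the Scala variant added in
+			// this patch builds a ValueStateDescriptor with a null default under the hood.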
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // Value state shortcut + QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState("matata"); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Now query + long expected = numElements; + + FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + for (int key = 0; key < NUM_SLOTS; key++) { + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + boolean success = false; + while (deadline.hasTimeLeft() && !success) { + Future future = getKvStateWithRetries(client, + jobId, + queryableState.getQueryableStateName(), + key, + serializedKey, + retryDelay); + + byte[] serializedValue = Await.result(future, deadline.timeLeft()); + + Tuple2 value = KvStateRequestSerializer.deserializeValue( + serializedValue, + queryableState.getValueSerializer()); + + assertEquals("Key mismatch", key, value.f0.intValue()); + if (expected == value.f1) { + success = true; + } else { + // Retry + Thread.sleep(50); + } + } + + assertTrue("Did not succeed query", success); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests simple list state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The tests succeeds after each subtask index is queried with + * a list of size numElements and each emitted tuple is part of the list. + */ + @Test + public void testListState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 128; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
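+			// ListState keeps every update, so a successful query returns numElements + 1
+			// entries per key (values 0..numElements), and the result is read with
+			// KvStateRequestSerializer.deserializeList() instead of deserializeValue().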
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // List state + ListStateDescriptor> listState = new ListStateDescriptor<>( + "any", + source.getType()); + + QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState("timon", listState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Now query + long expected = numElements + 1; // +1 for 0-value + + FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + for (int key = 0; key < NUM_SLOTS; key++) { + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + boolean success = false; + while (deadline.hasTimeLeft() && !success) { + Future future = getKvStateWithRetries(client, + jobId, + queryableState.getQueryableStateName(), + key, + serializedKey, + retryDelay); + + byte[] serializedValue = Await.result(future, deadline.timeLeft()); + + List> list = KvStateRequestSerializer.deserializeList( + serializedValue, + queryableState.getValueSerializer()); + + if (list.size() == expected) { + for (int i = 0; i < expected; i++) { + Tuple2 elem = list.get(i); + assertEquals("Key mismatch", key, elem.f0.intValue()); + assertEquals("Value mismatch", i, elem.f1.longValue()); + } + + success = true; + } else { + // Retry + Thread.sleep(50); + } + } + + assertTrue("Did not succeed query", success); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests simple folding state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The folding state sums these up and maps them to Strings. The + * test succeeds after each subtask index is queried with result n*(n+1)/2 + * (as a String). + */ + @Test + public void testFoldingState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
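+			// SumFold accumulates the Long values into a String accumulator (initial "0"),
+			// so the expected result per key is the Gauss sum numElements * (numElements + 1) / 2
+			// rendered as a String and compared with String.equals().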
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // Folding state + FoldingStateDescriptor, String> foldingState = + new FoldingStateDescriptor<>( + "any", + "0", + new SumFold(), + StringSerializer.INSTANCE); + + QueryableStateStream queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState("pumba", foldingState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Now query + String expected = Integer.toString(numElements * (numElements + 1) / 2); + + FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + for (int key = 0; key < NUM_SLOTS; key++) { + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + boolean success = false; + while (deadline.hasTimeLeft() && !success) { + Future future = getKvStateWithRetries(client, + jobId, + queryableState.getQueryableStateName(), + key, + serializedKey, + retryDelay); + + byte[] serializedValue = Await.result(future, deadline.timeLeft()); + + String value = KvStateRequestSerializer.deserializeValue( + serializedValue, + queryableState.getValueSerializer()); + + if (expected.equals(value)) { + success = true; + } else { + // Retry + Thread.sleep(50); + } + } + + assertTrue("Did not succeed query", success); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + /** + * Tests simple reducing state queryable state instance. Each source emits + * (subtaskIndex, 0)..(subtaskIndex, numElements) tuples, which are then + * queried. The reducing state instance sums these up. The test succeeds + * after each subtask index is queried with result n*(n+1)/2. + */ + @Test + public void testReducingState() throws Exception { + // Config + final Deadline deadline = TEST_TIMEOUT.fromNow(); + + final int numElements = 1024; + + final QueryableStateClient client = new QueryableStateClient(cluster.configuration()); + + JobID jobId = null; + try { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(NUM_SLOTS); + // Very important, because cluster is shared between tests and we + // don't explicitly check that all slots are available before + // submitting. 
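+			// SumReduce adds up the f1 fields, so the reduced value per key converges to
+			// numElements * (numElements + 1) / 2; the loop below polls until the
+			// deserialized Tuple2 reports exactly that sum.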
+ env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 1000)); + + DataStream> source = env + .addSource(new TestAscendingValueSource(numElements)); + + // Reducing state + ReducingStateDescriptor> reducingState = + new ReducingStateDescriptor<>( + "any", + new SumReduce(), + source.getType()); + + QueryableStateStream> queryableState = + source.keyBy(new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) throws Exception { + return value.f0; + } + }).asQueryableState("jungle", reducingState); + + // Submit the job graph + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + jobId = jobGraph.getJobID(); + + cluster.submitJobDetached(jobGraph); + + // Wait until job is running + + // Now query + long expected = numElements * (numElements + 1) / 2; + + FiniteDuration retryDelay = new FiniteDuration(1, TimeUnit.SECONDS); + for (int key = 0; key < NUM_SLOTS; key++) { + final byte[] serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( + key, + queryableState.getKeySerializer(), + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + boolean success = false; + while (deadline.hasTimeLeft() && !success) { + Future future = getKvStateWithRetries(client, + jobId, + queryableState.getQueryableStateName(), + key, + serializedKey, + retryDelay); + + byte[] serializedValue = Await.result(future, deadline.timeLeft()); + + Tuple2 value = KvStateRequestSerializer.deserializeValue( + serializedValue, + queryableState.getValueSerializer()); + + assertEquals("Key mismatch", key, value.f0.intValue()); + if (expected == value.f1) { + success = true; + } else { + // Retry + Thread.sleep(50); + } + } + + assertTrue("Did not succeed query", success); + } + } finally { + // Free cluster resources + if (jobId != null) { + Future cancellation = cluster + .getLeaderGateway(deadline.timeLeft()) + .ask(new JobManagerMessages.CancelJob(jobId), deadline.timeLeft()) + .mapTo(ClassTag$.MODULE$.apply(CancellationSuccess.class)); + + Await.ready(cancellation, deadline.timeLeft()); + } + + client.shutDown(); + } + } + + @SuppressWarnings("unchecked") + private static Future getKvStateWithRetries( + final QueryableStateClient client, + final JobID jobId, + final String queryName, + final int key, + final byte[] serializedKey, + final FiniteDuration retryDelay) { + + return client.getKvState(jobId, queryName, key, serializedKey) + .recoverWith(new Recover>() { + @Override + public Future recover(Throwable failure) throws Throwable { + if (failure instanceof AssertionError) { + return Futures.failed(failure); + } else { + // At startup some failures are expected + // due to races. Make sure that they don't + // fail this test. + return Patterns.after( + retryDelay, + TEST_ACTOR_SYSTEM.scheduler(), + TEST_ACTOR_SYSTEM.dispatcher(), + new Callable>() { + @Override + public Future call() throws Exception { + return getKvStateWithRetries( + client, + jobId, + queryName, + key, + serializedKey, + retryDelay); + } + }); + } + } + }, TEST_ACTOR_SYSTEM.dispatcher()); + } + + /** + * Test source producing (key, 0)..(key, maxValue) with key being the sub + * task index. + * + *
+	 * <p>
After all tuples have been emitted, the source waits to be cancelled + * and does not immediately finish. + */ + private static class TestAscendingValueSource extends RichParallelSourceFunction> { + + private final long maxValue; + private volatile boolean isRunning = true; + + public TestAscendingValueSource(long maxValue) { + Preconditions.checkArgument(maxValue >= 0); + this.maxValue = maxValue; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + } + + @Override + public void run(SourceContext> ctx) throws Exception { + // f0 => key + int key = getRuntimeContext().getIndexOfThisSubtask(); + Tuple2 record = new Tuple2<>(key, 0L); + + long currentValue = 0; + while (isRunning && currentValue <= maxValue) { + synchronized (ctx.getCheckpointLock()) { + record.f1 = currentValue; + ctx.collect(record); + } + + currentValue++; + } + + while (isRunning) { + synchronized (this) { + this.wait(); + } + } + } + + @Override + public void cancel() { + isRunning = false; + + synchronized (this) { + this.notifyAll(); + } + } + + } + + /** + * Test source producing (key, 1) tuples with random key in key range (numKeys). + */ + private static class TestKeyRangeSource extends RichParallelSourceFunction> + implements CheckpointListener { + + private final static AtomicLong LATEST_CHECKPOINT_ID = new AtomicLong(); + private final int numKeys; + private final ThreadLocalRandom random = ThreadLocalRandom.current(); + private volatile boolean isRunning = true; + + public TestKeyRangeSource(int numKeys) { + this.numKeys = numKeys; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + LATEST_CHECKPOINT_ID.set(0); + } + } + + @Override + public void run(SourceContext> ctx) throws Exception { + // f0 => key + Tuple2 record = new Tuple2<>(0, 1L); + + while (isRunning) { + synchronized (ctx.getCheckpointLock()) { + record.f0 = random.nextInt(numKeys); + ctx.collect(record); + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + LATEST_CHECKPOINT_ID.set(checkpointId); + } + } + } + + private static class SumFold implements FoldFunction, String> { + @Override + public String fold(String accumulator, Tuple2 value) throws Exception { + long acc = Long.valueOf(accumulator); + acc += value.f1; + return Long.toString(acc); + } + } + + private static class SumReduce implements ReduceFunction> { + @Override + public Tuple2 reduce(Tuple2 value1, Tuple2 value2) throws Exception { + value1.f1 += value2.f1; + return value1; + } + } + +}
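
All of the tests above perform the same client-side round trip: serialize the key together
with the VoidNamespace, ask the QueryableStateClient for the serialized value, and
deserialize the answer with the value serializer reported by the QueryableStateStream.
The following is a minimal sketch of that round trip, not part of the patch. It only uses
APIs that appear in this diff; the class and method names (QueryableStateLookupSketch,
queryOnce) are made up for illustration, and it assumes, as the tests do, that the int
passed to getKvState() is the key itself acting as its own hash.

    import org.apache.flink.api.common.JobID;
    import org.apache.flink.api.common.typeutils.TypeSerializer;
    import org.apache.flink.runtime.query.QueryableStateClient;
    import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
    import org.apache.flink.runtime.state.VoidNamespace;
    import org.apache.flink.runtime.state.VoidNamespaceSerializer;
    import scala.concurrent.Await;
    import scala.concurrent.Future;
    import scala.concurrent.duration.FiniteDuration;

    public class QueryableStateLookupSketch {

        /**
         * Queries a single key of a queryable state instance once and blocks for the
         * answer. The serializers must be the ones exposed by the QueryableStateStream
         * that published the state.
         */
        public static <V> V queryOnce(
                QueryableStateClient client,
                JobID jobId,
                String queryableStateName,
                int key,
                TypeSerializer<Integer> keySerializer,
                TypeSerializer<V> valueSerializer,
                FiniteDuration timeout) throws Exception {

            // Same wire format as KvState#getSerializedValue(byte[]) expects: key and
            // namespace serialized back to back, with VoidNamespace because the
            // non-windowed queryable state operators have no namespace.
            byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
                    key,
                    keySerializer,
                    VoidNamespace.INSTANCE,
                    VoidNamespaceSerializer.INSTANCE);

            // Location lookup and the network request happen asynchronously in the client.
            Future<byte[]> serializedResult =
                    client.getKvState(jobId, queryableStateName, key, serializedKeyAndNamespace);

            byte[] serializedValue = Await.result(serializedResult, timeout);

            // Deserialization is left to the caller, mirroring how the KvState itself
            // only hands out serialized bytes.
            return KvStateRequestSerializer.deserializeValue(serializedValue, valueSerializer);
        }
    }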