From 044b30e5cbdfe66fa6b481bd945ae6703521c544 Mon Sep 17 00:00:00 2001 From: kkloudas Date: Tue, 30 Jan 2018 17:06:16 +0100 Subject: [PATCH 1/6] [FLINK-8522] [checkpoint] Remove number of states from checkpoint. --- .../apache/flink/runtime/state/DefaultOperatorStateBackend.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index c45b9e35bb6ff..e6d6dc6549b5a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -280,8 +280,6 @@ public OperatorStateHandle performOperation() throws Exception { backendSerializationProxy.write(dov); - dov.writeInt(registeredStatesDeepCopies.size()); - for (Map.Entry> entry : registeredStatesDeepCopies.entrySet()) { From 759a3710e4af0a9367f6abfe78af4b5d55771754 Mon Sep 17 00:00:00 2001 From: kkloudas Date: Thu, 21 Dec 2017 13:51:35 +0100 Subject: [PATCH 2/6] [FLINK-4940] Add broadcast state to the OperatorStateBackend. 
--- .../kafka/FlinkKafkaConsumerBaseTest.java | 12 + .../api/common/state/BroadcastState.java | 90 +++++ .../api/common/state/OperatorStateStore.java | 32 ++ .../common/state/ReadOnlyBroadcastState.java | 70 ++++ .../OperatorStateRepartitioner.java | 4 +- .../RoundRobinOperatorStateRepartitioner.java | 80 +++-- .../checkpoint/StateAssignmentOperation.java | 5 +- .../state/AbstractKeyedStateBackend.java | 4 +- .../state/BackendWritableBroadcastState.java | 42 +++ .../state/DefaultOperatorStateBackend.java | 322 +++++++++++++++--- .../runtime/state/HeapBroadcastState.java | 154 +++++++++ .../OperatorBackendSerializationProxy.java | 73 ++-- ...endStateMetaInfoSnapshotReaderWriters.java | 142 +++++++- .../runtime/state/OperatorStateHandle.java | 5 +- ...gisteredBroadcastBackendStateMetaInfo.java | 230 +++++++++++++ .../checkpoint/CheckpointCoordinatorTest.java | 74 +++- .../savepoint/CheckpointTestUtils.java | 4 +- .../state/OperatorStateBackendTest.java | 199 ++++++++++- .../state/OperatorStateHandleTest.java | 5 +- .../state/SerializationProxiesTest.java | 109 +++++- 20 files changed, 1496 insertions(+), 160 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/state/BroadcastState.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/state/ReadOnlyBroadcastState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/BackendWritableBroadcastState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapBroadcastState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredBroadcastBackendStateMetaInfo.java diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java index 41b609ec31995..5040966337af4 
100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java @@ -19,9 +19,11 @@ package org.apache.flink.streaming.connectors.kafka; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; @@ -893,6 +895,11 @@ public ListState getOperatorState(ListStateDescriptor stateDescriptor) throw new UnsupportedOperationException(); } + @Override + public BroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) throws Exception { + throw new UnsupportedOperationException(); + } + @Override public ListState getListState(ListStateDescriptor stateDescriptor) throws Exception { throw new UnsupportedOperationException(); @@ -902,6 +909,11 @@ public ListState getListState(ListStateDescriptor stateDescriptor) thr public Set getRegisteredStateNames() { throw new UnsupportedOperationException(); } + + @Override + public Set getRegisteredBroadcastStateNames() { + throw new UnsupportedOperationException(); + } } private static class MockFunctionInitializationContext implements FunctionInitializationContext { diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/BroadcastState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/BroadcastState.java new file mode 100644 index 0000000000000..0cece41a46f2f --- /dev/null +++ 
b/flink-core/src/main/java/org/apache/flink/api/common/state/BroadcastState.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.state; + +import org.apache.flink.annotation.PublicEvolving; + +import java.util.Iterator; +import java.util.Map; + +/** + * A type of state that can be created to store the state of a {@code BroadcastStream}. This state assumes that + * the same elements are sent to all instances of an operator. + * + *

CAUTION: the user has to guarantee that all task instances store the same elements in this type of state. + * + *

Each operator instance individually maintains and stores elements in the broadcast state. The fact that the + * incoming stream is a broadcast one guarantees that all instances see all the elements. Upon recovery + * or re-scaling, the same state is given to each of the instances. To avoid hotspots, each task reads its previous + * partition, and if there are more tasks (scale up), then the new instances read from the old instances in a round + * robin fashion. This is why each instance has to guarantee that it stores the same elements as the rest. If not, + * upon recovery or rescaling you may have unpredictable redistribution of the partitions, thus unpredictable results. + * + * @param The key type of the elements in the {@link BroadcastState}. + * @param The value type of the elements in the {@link BroadcastState}. + */ +@PublicEvolving +public interface BroadcastState extends ReadOnlyBroadcastState { + + /** + * Associates a new value with the given key. + * + * @param key The key of the mapping + * @param value The new value of the mapping + * + * @throws Exception Thrown if the system cannot access the state. + */ + void put(K key, V value) throws Exception; + + /** + * Copies all of the mappings from the given map into the state. + * + * @param map The mappings to be stored in this state + * + * @throws Exception Thrown if the system cannot access the state. + */ + void putAll(Map map) throws Exception; + + /** + * Deletes the mapping of the given key. + * + * @param key The key of the mapping + * + * @throws Exception Thrown if the system cannot access the state. + */ + void remove(K key) throws Exception; + + /** + * Iterates over all the mappings in the state. + * + * @return An iterator over all the mappings in the state + * + * @throws Exception Thrown if the system cannot access the state. 
+ */ + Iterator> iterator() throws Exception; + + /** + * Returns all the mappings in the state + * + * @return An iterable view of all the key-value pairs in the state. + * + * @throws Exception Thrown if the system cannot access the state. + */ + Iterable> entries() throws Exception; +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java index bf220419db181..c2037e0b58425 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java @@ -29,6 +29,31 @@ @PublicEvolving public interface OperatorStateStore { + /** + * Creates (or restores) a {@link BroadcastState broadcast state}. This type of state can only be created to store + * the state of a {@code BroadcastStream}. Each state is registered under a unique name. + * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore). + * The returned broadcast state has {@code key-value} format. + * + *

CAUTION: the user has to guarantee that all task instances store the same elements in this type of state. + * + *

Each operator instance individually maintains and stores elements in the broadcast state. The fact that the + * incoming stream is a broadcast one guarantees that all instances see all the elements. Upon recovery + * or re-scaling, the same state is given to each of the instances. To avoid hotspots, each task reads its previous + * partition, and if there are more tasks (scale up), then the new instances read from the old instances in a round + * robin fashion. This is why each instance has to guarantee that it stores the same elements as the rest. If not, + * upon recovery or rescaling you may have unpredictable redistribution of the partitions, thus unpredictable results. + * + * @param stateDescriptor The descriptor for this state, providing a name, a serializer for the keys and one for the + * values. + * @param The type of the keys in the broadcast state. + * @param The type of the values in the broadcast state. + * + * @return The {@link BroadcastState Broadcast State}. + * @throws Exception + */ + BroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) throws Exception; + /** * Creates (or restores) a list state. Each state is registered under a unique name. * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore). @@ -83,6 +108,13 @@ public interface OperatorStateStore { */ Set getRegisteredStateNames(); + /** + * Returns a set with the names of all currently registered broadcast states. + * + * @return set of names for all registered broadcast states. 
+ */ + Set getRegisteredBroadcastStateNames(); + // ------------------------------------------------------------------------------------------- // Deprecated methods // ------------------------------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/ReadOnlyBroadcastState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/ReadOnlyBroadcastState.java new file mode 100644 index 0000000000000..4d3f2e72a97c8 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/ReadOnlyBroadcastState.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.state; + +import org.apache.flink.annotation.PublicEvolving; + +import java.util.Map; + +/** + * A read-only view of the {@link BroadcastState}. + * + *

Although read-only, the user code should not modify the value + * returned by the {@link #get(Object)} or the entries of the immutable + * iterator returned by the {@link #immutableEntries()}, as this can lead to + * inconsistent states. The reason for this is that we do not create extra + * copies of the elements for performance reasons. + * + * @param The key type of the elements in the {@link ReadOnlyBroadcastState}. + * @param The value type of the elements in the {@link ReadOnlyBroadcastState}. + */ +@PublicEvolving +public interface ReadOnlyBroadcastState extends State { + + /** + * Returns the current value associated with the given key. + * + *

The user code must not modify the value returned, as + * this can lead to inconsistent states. + * + * @param key The key of the mapping + * @return The value of the mapping with the given key + * + * @throws Exception Thrown if the system cannot access the state. + */ + V get(K key) throws Exception; + + /** + * Returns whether there exists the given mapping. + * + * @param key The key of the mapping + * @return True if there exists a mapping whose key equals to the given key + * + * @throws Exception Thrown if the system cannot access the state. + */ + boolean contains(K key) throws Exception; + + /** + * Returns an immutable {@link Iterable} over the entries in the state. + * + *

The user code must not modify the entries of the returned immutable + * iterator, as this can lead to inconsistent states. + */ + Iterable> immutableEntries() throws Exception; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java index 98810f187c706..090f48a3c87ad 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java @@ -31,12 +31,12 @@ public interface OperatorStateRepartitioner { /** * @param previousParallelSubtaskStates List of state handles to the parallel subtask states of an operator, as they * have been checkpointed. - * @param parallelism The parallelism that we consider for the state redistribution. Determines the size of the + * @param newParallelism The parallelism that we consider for the state redistribution. Determines the size of the * returned list. * @return List with one entry per parallel subtask. Each subtask receives now one collection of states that build * of the new total state for this subtask. 
*/ List> repartitionState( List previousParallelSubtaskStates, - int parallelism); + int newParallelism); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java index 4513ef80b32b1..e09b677c3b0d8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java @@ -42,10 +42,10 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart @Override public List> repartitionState( List previousParallelSubtaskStates, - int parallelism) { + int newParallelism) { Preconditions.checkNotNull(previousParallelSubtaskStates); - Preconditions.checkArgument(parallelism > 0); + Preconditions.checkArgument(newParallelism > 0); // Reorganize: group by (State Name -> StreamStateHandle + Offsets) GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates); @@ -55,11 +55,11 @@ public List> repartitionState( } // Assemble result from all merge maps - List> result = new ArrayList<>(parallelism); + List> result = new ArrayList<>(newParallelism); // Do the actual repartitioning for all named states List> mergeMapList = - repartition(nameToStateByMode, parallelism); + repartition(nameToStateByMode, newParallelism); for (int i = 0; i < mergeMapList.size(); ++i) { result.add(i, new ArrayList<>(mergeMapList.get(i).values())); @@ -72,8 +72,7 @@ public List> repartitionState( * Group by the different named states. 
*/ @SuppressWarnings("unchecked, rawtype") - private GroupByStateNameResults groupByStateName( - List previousParallelSubtaskStates) { + private GroupByStateNameResults groupByStateName(List previousParallelSubtaskStates) { //Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo) EnumMap(OperatorStateHandle.Mode.class); for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) { - Map>> map = new HashMap<>(); - nameToStateByMode.put( - mode, - new HashMap>>()); + nameToStateByMode.put(mode, new HashMap<>()); } for (OperatorStateHandle psh : previousParallelSubtaskStates) { @@ -120,14 +116,14 @@ private GroupByStateNameResults groupByStateName( */ private List> repartition( GroupByStateNameResults nameToStateByMode, - int parallelism) { + int newParallelism) { // We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps - List> mergeMapList = new ArrayList<>(parallelism); + List> mergeMapList = new ArrayList<>(newParallelism); // Initialize - for (int i = 0; i < parallelism; ++i) { - mergeMapList.add(new HashMap()); + for (int i = 0; i < newParallelism; ++i) { + mergeMapList.add(new HashMap<>()); } // Start with the state handles we distribute round robin by splitting by offsets @@ -150,15 +146,15 @@ private List> repartition( // Repartition the state across the parallel operator instances int lstIdx = 0; int offsetIdx = 0; - int baseFraction = totalPartitions / parallelism; - int remainder = totalPartitions % parallelism; + int baseFraction = totalPartitions / newParallelism; + int remainder = totalPartitions % newParallelism; int newStartParallelOp = startParallelOp; - for (int i = 0; i < parallelism; ++i) { + for (int i = 0; i < newParallelism; ++i) { // Preparation: calculate the actual index considering wrap around - int parallelOpIdx = (i + startParallelOp) % parallelism; + int parallelOpIdx = (i + startParallelOp) % newParallelism; // Now calculate the number of partitions we will assign to the 
parallel instance in this round ... int numberOfPartitionsToAssign = baseFraction; @@ -209,10 +205,7 @@ private List> repartition( Map mergeMap = mergeMapList.get(parallelOpIdx); OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0); if (operatorStateHandle == null) { - operatorStateHandle = new OperatorStateHandle( - new HashMap(), - handleWithOffsets.f0); - + operatorStateHandle = new OperatorStateHandle(new HashMap<>(), handleWithOffsets.f0); mergeMap.put(handleWithOffsets.f0, operatorStateHandle); } operatorStateHandle.getStateNameToPartitionOffsets().put( @@ -226,30 +219,51 @@ private List> repartition( // Now we also add the state handles marked for broadcast to all parallel instances Map>> broadcastNameToState = - nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST); + nameToStateByMode.getByMode(OperatorStateHandle.Mode.UNION); - for (int i = 0; i < parallelism; ++i) { + for (int i = 0; i < newParallelism; ++i) { Map mergeMap = mergeMapList.get(i); for (Map.Entry>> e : broadcastNameToState.entrySet()) { - List> current = e.getValue(); - - for (Tuple2 handleWithMetaInfo : current) { + for (Tuple2 handleWithMetaInfo : e.getValue()) { OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0); if (operatorStateHandle == null) { - operatorStateHandle = new OperatorStateHandle( - new HashMap(), - handleWithMetaInfo.f0); - + operatorStateHandle = new OperatorStateHandle(new HashMap<>(), handleWithMetaInfo.f0); mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle); } operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1); } } } + + // Now we also add the state handles marked for uniform broadcast to all parallel instances + Map>> uniformBroadcastNameToState = + nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST); + + for (int i = 0; i < newParallelism; ++i) { + + final Map mergeMap = mergeMapList.get(i); + + // for each name, pick the i-th entry + for 
(Map.Entry>> e : + uniformBroadcastNameToState.entrySet()) { + + int oldParallelism = e.getValue().size(); + + Tuple2 handleWithMetaInfo = + e.getValue().get(i % oldParallelism); + + OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0); + if (operatorStateHandle == null) { + operatorStateHandle = new OperatorStateHandle(new HashMap<>(), handleWithMetaInfo.f0); + mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle); + } + operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1); + } + } return mergeMapList; } @@ -257,7 +271,7 @@ private static final class GroupByStateNameResults { private final EnumMap>>> byMode; - public GroupByStateNameResults( + GroupByStateNameResults( EnumMap>>> byMode) { this.byMode = Preconditions.checkNotNull(byMode); @@ -268,4 +282,4 @@ public Map> applyRepartitioner( Map partitionOffsets = operatorStateHandle.getStateNameToPartitionOffsets(); - for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) { // if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning - if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) { + if (OperatorStateHandle.Mode.UNION.equals(metaInfo.getDistributionMode())) { return opStateRepartitioner.repartitionState( chainOpParallelStates, newParallelism); @@ -639,7 +638,7 @@ public static List> applyRepartitioner( /** * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct * key group index for the given subtask {@link KeyGroupRange}. - *

+ * *

This is publicly visible to be used in tests. */ public static List getKeyedStateHandles( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index 0ffde7a8a7ba4..cc53c0cb895f9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -98,9 +98,7 @@ public abstract class AbstractKeyedStateBackend private final ExecutionConfig executionConfig; - /** - * Decorates the input and output streams to write key-groups compressed. - */ + /** Decorates the input and output streams to write key-groups compressed. */ protected final StreamCompressionDecorator keyGroupCompressionDecorator; public AbstractKeyedStateBackend( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/BackendWritableBroadcastState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BackendWritableBroadcastState.java new file mode 100644 index 0000000000000..8daf07c018abe --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BackendWritableBroadcastState.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.core.fs.FSDataOutputStream; + +import java.io.IOException; + +/** + * An interface with methods related to the interplay between the {@link BroadcastState Broadcast State} and + * the {@link OperatorStateBackend}. + * + * @param The key type of the elements in the {@link BroadcastState Broadcast State}. + * @param The value type of the elements in the {@link BroadcastState Broadcast State}. + */ +public interface BackendWritableBroadcastState extends BroadcastState { + + BackendWritableBroadcastState deepCopy(); + + long write(FSDataOutputStream out) throws IOException; + + void setStateMetaInfo(RegisteredBroadcastBackendStateMetaInfo stateMetaInfo); + + RegisteredBroadcastBackendStateMetaInfo getStateMetaInfo(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index e6d6dc6549b5a..f4866439b467f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -22,6 +22,8 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.typeutils.CompatibilityResult; import org.apache.flink.api.common.typeutils.CompatibilityUtil; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -69,7 +71,12 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { /** * 
Map for all registered operator states. Maps state name -> state */ - private final Map> registeredStates; + private final Map> registeredOperatorStates; + + /** + * Map for all registered operator broadcast states. Maps state name -> state + */ + private final Map> registeredBroadcastStates; /** * CloseableRegistry to participate in the tasks lifecycle. @@ -102,12 +109,17 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { *

TODO this map can be removed when eager-state registration is in place. * TODO we currently need this cached to check state migration strategies when new serializers are registered. */ - private final Map> restoredStateMetaInfos; + private final Map> restoredOperatorStateMetaInfos; + + /** + * Map of state names to their corresponding restored broadcast state meta info. + */ + private final Map> restoredBroadcastStateMetaInfos; /** * Cache of already accessed states. * - *

In contrast to {@link #registeredStates} and {@link #restoredStateMetaInfos} which may be repopulated + *

In contrast to {@link #registeredOperatorStates} and {@link #restoredOperatorStateMetaInfos} which may be repopulated * with restored state, this map is always empty at the beginning. * *

TODO this map should be moved to a base class once we have proper hierarchy for the operator state backends. @@ -116,6 +128,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend { */ private final HashMap> accessedStatesByName; + private final Map> accessedBroadcastStatesByName; + public DefaultOperatorStateBackend( ClassLoader userClassLoader, ExecutionConfig executionConfig, @@ -125,10 +139,13 @@ public DefaultOperatorStateBackend( this.userClassloader = Preconditions.checkNotNull(userClassLoader); this.executionConfig = executionConfig; this.javaSerializer = new JavaSerializer<>(); - this.registeredStates = new HashMap<>(); + this.registeredOperatorStates = new HashMap<>(); + this.registeredBroadcastStates = new HashMap<>(); this.asynchronousSnapshots = asynchronousSnapshots; this.accessedStatesByName = new HashMap<>(); - this.restoredStateMetaInfos = new HashMap<>(); + this.accessedBroadcastStatesByName = new HashMap<>(); + this.restoredOperatorStateMetaInfos = new HashMap<>(); + this.restoredBroadcastStateMetaInfos = new HashMap<>(); } public ExecutionConfig getExecutionConfig() { @@ -137,7 +154,12 @@ public ExecutionConfig getExecutionConfig() { @Override public Set getRegisteredStateNames() { - return registeredStates.keySet(); + return registeredOperatorStates.keySet(); + } + + @Override + public Set getRegisteredBroadcastStateNames() { + return registeredBroadcastStates.keySet(); } @Override @@ -148,13 +170,95 @@ public void close() throws IOException { @Override public void dispose() { IOUtils.closeQuietly(closeStreamOnCancelRegistry); - registeredStates.clear(); + registeredOperatorStates.clear(); + registeredBroadcastStates.clear(); } // ------------------------------------------------------------------------------------------- // State access methods // ------------------------------------------------------------------------------------------- + @Override + public BroadcastState getBroadcastState(final MapStateDescriptor 
stateDescriptor) throws StateMigrationException { + + Preconditions.checkNotNull(stateDescriptor); + String name = Preconditions.checkNotNull(stateDescriptor.getName()); + + @SuppressWarnings("unchecked") + BackendWritableBroadcastState previous = (BackendWritableBroadcastState) accessedBroadcastStatesByName.get(name); + if (previous != null) { + checkStateNameAndMode( + previous.getStateMetaInfo().getName(), + name, + previous.getStateMetaInfo().getAssignmentMode(), + OperatorStateHandle.Mode.BROADCAST); + return previous; + } + + stateDescriptor.initializeSerializerUnlessSet(getExecutionConfig()); + TypeSerializer broadcastStateKeySerializer = Preconditions.checkNotNull(stateDescriptor.getKeySerializer()); + TypeSerializer broadcastStateValueSerializer = Preconditions.checkNotNull(stateDescriptor.getValueSerializer()); + + BackendWritableBroadcastState broadcastState = (BackendWritableBroadcastState) registeredBroadcastStates.get(name); + + if (broadcastState == null) { + broadcastState = new HeapBroadcastState<>( + new RegisteredBroadcastBackendStateMetaInfo<>( + name, + OperatorStateHandle.Mode.BROADCAST, + broadcastStateKeySerializer, + broadcastStateValueSerializer)); + registeredBroadcastStates.put(name, broadcastState); + } else { + // has restored state; check compatibility of new state access + + checkStateNameAndMode( + broadcastState.getStateMetaInfo().getName(), + name, + broadcastState.getStateMetaInfo().getAssignmentMode(), + OperatorStateHandle.Mode.BROADCAST); + + @SuppressWarnings("unchecked") + RegisteredBroadcastBackendStateMetaInfo.Snapshot restoredMetaInfo = + (RegisteredBroadcastBackendStateMetaInfo.Snapshot) restoredBroadcastStateMetaInfos.get(name); + + // check compatibility to determine if state migration is required + CompatibilityResult keyCompatibility = CompatibilityUtil.resolveCompatibilityResult( + restoredMetaInfo.getKeySerializer(), + UnloadableDummyTypeSerializer.class, + restoredMetaInfo.getKeySerializerConfigSnapshot(), + 
broadcastStateKeySerializer); + + CompatibilityResult valueCompatibility = CompatibilityUtil.resolveCompatibilityResult( + restoredMetaInfo.getValueSerializer(), + UnloadableDummyTypeSerializer.class, + restoredMetaInfo.getValueSerializerConfigSnapshot(), + broadcastStateValueSerializer); + + if (!keyCompatibility.isRequiresMigration() && !valueCompatibility.isRequiresMigration()) { + // new serializer is compatible; use it to replace the old serializer + broadcastState.setStateMetaInfo( + new RegisteredBroadcastBackendStateMetaInfo<>( + name, + OperatorStateHandle.Mode.BROADCAST, + broadcastStateKeySerializer, + broadcastStateValueSerializer)); + } else { + // TODO state migration currently isn't possible. + + // NOTE: for heap backends, it is actually fine to proceed here without failing the restore, + // since the state has already been deserialized to objects and we can just continue with + // the new serializer; we're deliberately failing here for now to have equal functionality with + // the RocksDB backend to avoid confusion for users. 
+ + throw new StateMigrationException("State migration isn't supported, yet."); + } + } + + accessedBroadcastStatesByName.put(name, broadcastState); + return broadcastState; + } + @Override public ListState getListState(ListStateDescriptor stateDescriptor) throws Exception { return getListState(stateDescriptor, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); @@ -162,7 +266,7 @@ public ListState getListState(ListStateDescriptor stateDescriptor) thr @Override public ListState getUnionListState(ListStateDescriptor stateDescriptor) throws Exception { - return getListState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST); + return getListState(stateDescriptor, OperatorStateHandle.Mode.UNION); } // ------------------------------------------------------------------------------------------- @@ -203,23 +307,39 @@ public RunnableFuture snapshot( final long syncStartTime = System.currentTimeMillis(); - if (registeredStates.isEmpty()) { + if (registeredOperatorStates.isEmpty() && registeredBroadcastStates.isEmpty()) { return DoneFuture.nullValue(); } - final Map> registeredStatesDeepCopies = - new HashMap<>(registeredStates.size()); + final Map> registeredOperatorStatesDeepCopies = + new HashMap<>(registeredOperatorStates.size()); + final Map> registeredBroadcastStatesDeepCopies = + new HashMap<>(registeredBroadcastStates.size()); - // eagerly create deep copies of the list states in the sync phase, so that we can use them in the async writing ClassLoader snapshotClassLoader = Thread.currentThread().getContextClassLoader(); Thread.currentThread().setContextClassLoader(userClassloader); try { - for (Map.Entry> entry : this.registeredStates.entrySet()) { - PartitionableListState listState = entry.getValue(); - if (null != listState) { - listState = listState.deepCopy(); + // eagerly create deep copies of the list and the broadcast states (if any) + // in the synchronous phase, so that we can use them in the async writing. 
+ + if (!registeredOperatorStates.isEmpty()) { + for (Map.Entry> entry : registeredOperatorStates.entrySet()) { + PartitionableListState listState = entry.getValue(); + if (null != listState) { + listState = listState.deepCopy(); + } + registeredOperatorStatesDeepCopies.put(entry.getKey(), listState); + } + } + + if (!registeredBroadcastStates.isEmpty()) { + for (Map.Entry> entry : registeredBroadcastStates.entrySet()) { + BackendWritableBroadcastState broadcastState = entry.getValue(); + if (null != broadcastState) { + broadcastState = broadcastState.deepCopy(); + } + registeredBroadcastStatesDeepCopies.put(entry.getKey(), broadcastState); } - registeredStatesDeepCopies.put(entry.getKey(), listState); } } finally { Thread.currentThread().setContextClassLoader(snapshotClassLoader); @@ -263,25 +383,38 @@ public OperatorStateHandle performOperation() throws Exception { CheckpointStreamFactory.CheckpointStateOutputStream localOut = this.out; - final Map writtenStatesMetaData = - new HashMap<>(registeredStatesDeepCopies.size()); + // get the registered operator state infos ... + List> operatorMetaInfoSnapshots = + new ArrayList<>(registeredOperatorStatesDeepCopies.size()); + + for (Map.Entry> entry : registeredOperatorStatesDeepCopies.entrySet()) { + operatorMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot()); + } - List> metaInfoSnapshots = - new ArrayList<>(registeredStatesDeepCopies.size()); + // ... get the registered broadcast operator state infos ... + List> broadcastMetaInfoSnapshots = + new ArrayList<>(registeredBroadcastStatesDeepCopies.size()); - for (Map.Entry> entry : registeredStatesDeepCopies.entrySet()) { - metaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot()); + for (Map.Entry> entry : registeredBroadcastStatesDeepCopies.entrySet()) { + broadcastMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot()); } + // ... write them all in the checkpoint stream ... 
DataOutputView dov = new DataOutputViewStreamWrapper(localOut); OperatorBackendSerializationProxy backendSerializationProxy = - new OperatorBackendSerializationProxy(metaInfoSnapshots); + new OperatorBackendSerializationProxy(operatorMetaInfoSnapshots, broadcastMetaInfoSnapshots); backendSerializationProxy.write(dov); + // ... and then go for the states ... + + // we put BOTH normal and broadcast state metadata here + final Map writtenStatesMetaData = + new HashMap<>(registeredOperatorStatesDeepCopies.size() + registeredBroadcastStatesDeepCopies.size()); + for (Map.Entry> entry : - registeredStatesDeepCopies.entrySet()) { + registeredOperatorStatesDeepCopies.entrySet()) { PartitionableListState value = entry.getValue(); long[] partitionOffsets = value.write(localOut); @@ -291,6 +424,19 @@ public OperatorStateHandle performOperation() throws Exception { new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode)); } + // ... and the broadcast states themselves ... + for (Map.Entry> entry : + registeredBroadcastStatesDeepCopies.entrySet()) { + + BackendWritableBroadcastState value = entry.getValue(); + long[] partitionOffsets = {value.write(localOut)}; + OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode(); + writtenStatesMetaData.put( + entry.getKey(), + new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode)); + } + + // ... and, finally, create the state handle. 
OperatorStateHandle retValue = null; if (closeStreamOnCancelRegistry.unregisterCloseable(out)) { @@ -348,11 +494,11 @@ public void restore(Collection restoreSnapshots) throws Exc backendSerializationProxy.read(new DataInputViewStreamWrapper(in)); - List> restoredMetaInfoSnapshots = - backendSerializationProxy.getStateMetaInfoSnapshots(); + List> restoredOperatorMetaInfoSnapshots = + backendSerializationProxy.getOperatorStateMetaInfoSnapshots(); // Recreate all PartitionableListStates from the meta info - for (RegisteredOperatorBackendStateMetaInfo.Snapshot restoredMetaInfo : restoredMetaInfoSnapshots) { + for (RegisteredOperatorBackendStateMetaInfo.Snapshot restoredMetaInfo : restoredOperatorMetaInfoSnapshots) { if (restoredMetaInfo.getPartitionStateSerializer() == null || restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer) { @@ -368,9 +514,9 @@ public void restore(Collection restoreSnapshots) throws Exc " not be loaded. This is a temporary restriction that will be fixed in future versions."); } - restoredStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo); + restoredOperatorStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo); - PartitionableListState listState = registeredStates.get(restoredMetaInfo.getName()); + PartitionableListState listState = registeredOperatorStates.get(restoredMetaInfo.getName()); if (null == listState) { listState = new PartitionableListState<>( @@ -379,22 +525,66 @@ public void restore(Collection restoreSnapshots) throws Exc restoredMetaInfo.getPartitionStateSerializer(), restoredMetaInfo.getAssignmentMode())); - registeredStates.put(listState.getStateMetaInfo().getName(), listState); + registeredOperatorStates.put(listState.getStateMetaInfo().getName(), listState); + } else { + // TODO with eager state registration in place, check here for serializer migration strategies + } + } + + // ... and then get back the broadcast state. 
+ List> restoredBroadcastMetaInfoSnapshots = + backendSerializationProxy.getBroadcastStateMetaInfoSnapshots(); + + for (RegisteredBroadcastBackendStateMetaInfo.Snapshot restoredMetaInfo : restoredBroadcastMetaInfoSnapshots) { + + if (restoredMetaInfo.getKeySerializer() == null || restoredMetaInfo.getValueSerializer() == null || + restoredMetaInfo.getKeySerializer() instanceof UnloadableDummyTypeSerializer || + restoredMetaInfo.getValueSerializer() instanceof UnloadableDummyTypeSerializer) { + + // must fail now if the previous serializer cannot be restored because there is no serializer + // capable of reading previous state + // TODO when eager state registration is in place, we can try to get a convert deserializer + // TODO from the newly registered serializer instead of simply failing here + + throw new IOException("Unable to restore broadcast state [" + restoredMetaInfo.getName() + "]." + + " The previous key and value serializers of the state must be present; the serializers could" + + " have been removed from the classpath, or their implementations have changed and could" + + " not be loaded. 
This is a temporary restriction that will be fixed in future versions."); + } + + restoredBroadcastStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo); + + BackendWritableBroadcastState broadcastState = registeredBroadcastStates.get(restoredMetaInfo.getName()); + + if (broadcastState == null) { + broadcastState = new HeapBroadcastState<>( + new RegisteredBroadcastBackendStateMetaInfo<>( + restoredMetaInfo.getName(), + restoredMetaInfo.getAssignmentMode(), + restoredMetaInfo.getKeySerializer(), + restoredMetaInfo.getValueSerializer())); + + registeredBroadcastStates.put(broadcastState.getStateMetaInfo().getName(), broadcastState); } else { // TODO with eager state registration in place, check here for serializer migration strategies } } - // Restore all the state in PartitionableListStates + // Restore all the states for (Map.Entry nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) { - PartitionableListState stateListForName = registeredStates.get(nameToOffsets.getKey()); - - Preconditions.checkState(null != stateListForName, "Found state without " + - "corresponding meta info: " + nameToOffsets.getKey()); + final String stateName = nameToOffsets.getKey(); - deserializeStateValues(stateListForName, in, nameToOffsets.getValue()); + PartitionableListState listStateForName = registeredOperatorStates.get(stateName); + if (listStateForName == null) { + BackendWritableBroadcastState broadcastStateForName = registeredBroadcastStates.get(stateName); + Preconditions.checkState(broadcastStateForName != null, "Found state without " + + "corresponding meta info: " + stateName); + deserializeBroadcastStateValues(broadcastStateForName, in, nameToOffsets.getValue()); + } else { + deserializeOperatorStateValues(listStateForName, in, nameToOffsets.getValue()); + } } } finally { @@ -428,7 +618,7 @@ static final class PartitionableListState implements ListState { */ private final ArrayListSerializer internalListCopySerializer; - public 
PartitionableListState(RegisteredOperatorBackendStateMetaInfo stateMetaInfo) { + PartitionableListState(RegisteredOperatorBackendStateMetaInfo stateMetaInfo) { this(stateMetaInfo, new ArrayList()); } @@ -513,7 +703,7 @@ public void addAll(List values) throws Exception { private ListState getListState( ListStateDescriptor stateDescriptor, - OperatorStateHandle.Mode mode) throws IOException, StateMigrationException { + OperatorStateHandle.Mode mode) throws StateMigrationException { Preconditions.checkNotNull(stateDescriptor); String name = Preconditions.checkNotNull(stateDescriptor.getName()); @@ -521,7 +711,11 @@ private ListState getListState( @SuppressWarnings("unchecked") PartitionableListState previous = (PartitionableListState) accessedStatesByName.get(name); if (previous != null) { - checkStateNameAndMode(previous.getStateMetaInfo(), name, mode); + checkStateNameAndMode( + previous.getStateMetaInfo().getName(), + name, + previous.getStateMetaInfo().getAssignmentMode(), + mode); return previous; } @@ -533,7 +727,7 @@ private ListState getListState( TypeSerializer partitionStateSerializer = Preconditions.checkNotNull(stateDescriptor.getElementSerializer()); @SuppressWarnings("unchecked") - PartitionableListState partitionableListState = (PartitionableListState) registeredStates.get(name); + PartitionableListState partitionableListState = (PartitionableListState) registeredOperatorStates.get(name); if (null == partitionableListState) { // no restored state for the state name; simply create new state holder @@ -544,15 +738,19 @@ private ListState getListState( partitionStateSerializer, mode)); - registeredStates.put(name, partitionableListState); + registeredOperatorStates.put(name, partitionableListState); } else { // has restored state; check compatibility of new state access - checkStateNameAndMode(partitionableListState.getStateMetaInfo(), name, mode); + checkStateNameAndMode( + partitionableListState.getStateMetaInfo().getName(), + name, + 
partitionableListState.getStateMetaInfo().getAssignmentMode(), + mode); @SuppressWarnings("unchecked") RegisteredOperatorBackendStateMetaInfo.Snapshot restoredMetaInfo = - (RegisteredOperatorBackendStateMetaInfo.Snapshot) restoredStateMetaInfos.get(name); + (RegisteredOperatorBackendStateMetaInfo.Snapshot) restoredOperatorStateMetaInfos.get(name); // check compatibility to determine if state migration is required CompatibilityResult stateCompatibility = CompatibilityUtil.resolveCompatibilityResult( @@ -581,7 +779,7 @@ private ListState getListState( return partitionableListState; } - private static void deserializeStateValues( + private static void deserializeOperatorStateValues( PartitionableListState stateListForName, FSDataInputStream in, OperatorStateHandle.StateMetaInfo metaInfo) throws IOException { @@ -599,21 +797,45 @@ private static void deserializeStateValues( } } + private static void deserializeBroadcastStateValues( + final BackendWritableBroadcastState broadcastStateForName, + final FSDataInputStream in, + final OperatorStateHandle.StateMetaInfo metaInfo) throws Exception { + + if (metaInfo != null) { + long[] offsets = metaInfo.getOffsets(); + if (offsets != null) { + + TypeSerializer keySerializer = broadcastStateForName.getStateMetaInfo().getKeySerializer(); + TypeSerializer valueSerializer = broadcastStateForName.getStateMetaInfo().getValueSerializer(); + + in.seek(offsets[0]); + + DataInputView div = new DataInputViewStreamWrapper(in); + int size = div.readInt(); + for (int i = 0; i < size; i++) { + broadcastStateForName.put(keySerializer.deserialize(div), valueSerializer.deserialize(div)); + } + } + } + } + private static void checkStateNameAndMode( - RegisteredOperatorBackendStateMetaInfo previousMetaInfo, + String actualName, String expectedName, + OperatorStateHandle.Mode actualMode, OperatorStateHandle.Mode expectedMode) { Preconditions.checkState( - previousMetaInfo.getName().equals(expectedName), + actualName.equals(expectedName), 
"Incompatible state names. " + - "Was [" + previousMetaInfo.getName() + "], " + + "Was [" + actualName + "], " + "registered with [" + expectedName + "]."); Preconditions.checkState( - previousMetaInfo.getAssignmentMode().equals(expectedMode), + actualMode.equals(expectedMode), "Incompatible state assignment modes. " + - "Was [" + previousMetaInfo.getAssignmentMode() + "], " + + "Was [" + actualMode + "], " + "registered with [" + expectedMode + "]."); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapBroadcastState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapBroadcastState.java new file mode 100644 index 0000000000000..42e68f3f2118a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapBroadcastState.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +/** + * A {@link BroadcastState Broadcast State} backed a heap-based {@link Map}. + * + * @param The key type of the elements in the {@link BroadcastState Broadcast State}. + * @param The value type of the elements in the {@link BroadcastState Broadcast State}. + */ +public class HeapBroadcastState implements BackendWritableBroadcastState { + + /** + * Meta information of the state, including state name, assignment mode, and serializer. + */ + private RegisteredBroadcastBackendStateMetaInfo stateMetaInfo; + + /** + * The internal map the holds the elements of the state. + */ + private final Map backingMap; + + /** + * A serializer that allows to perform deep copies of internal map state. 
+ */ + private final MapSerializer internalMapCopySerializer; + + HeapBroadcastState(RegisteredBroadcastBackendStateMetaInfo stateMetaInfo) { + this(stateMetaInfo, new HashMap<>()); + } + + private HeapBroadcastState(final RegisteredBroadcastBackendStateMetaInfo stateMetaInfo, final Map internalMap) { + + this.stateMetaInfo = Preconditions.checkNotNull(stateMetaInfo); + this.backingMap = Preconditions.checkNotNull(internalMap); + this.internalMapCopySerializer = new MapSerializer<>(stateMetaInfo.getKeySerializer(), stateMetaInfo.getValueSerializer()); + } + + private HeapBroadcastState(HeapBroadcastState toCopy) { + this(toCopy.stateMetaInfo, toCopy.internalMapCopySerializer.copy(toCopy.backingMap)); + } + + @Override + public void setStateMetaInfo(RegisteredBroadcastBackendStateMetaInfo stateMetaInfo) { + this.stateMetaInfo = stateMetaInfo; + } + + @Override + public RegisteredBroadcastBackendStateMetaInfo getStateMetaInfo() { + return stateMetaInfo; + } + + @Override + public HeapBroadcastState deepCopy() { + return new HeapBroadcastState<>(this); + } + + @Override + public void clear() { + backingMap.clear(); + } + + @Override + public String toString() { + return "HeapBroadcastState{" + + "stateMetaInfo=" + stateMetaInfo + + ", backingMap=" + backingMap + + ", internalMapCopySerializer=" + internalMapCopySerializer + + '}'; + } + + @Override + public long write(FSDataOutputStream out) throws IOException { + long partitionOffset = out.getPos(); + + DataOutputView dov = new DataOutputViewStreamWrapper(out); + dov.writeInt(backingMap.size()); + for (Map.Entry entry: backingMap.entrySet()) { + getStateMetaInfo().getKeySerializer().serialize(entry.getKey(), dov); + getStateMetaInfo().getValueSerializer().serialize(entry.getValue(), dov); + } + + return partitionOffset; + } + + @Override + public V get(K key) { + return backingMap.get(key); + } + + @Override + public void put(K key, V value) { + backingMap.put(key, value); + } + + @Override + public void putAll(Map 
map) { + backingMap.putAll(map); + } + + @Override + public void remove(K key) { + backingMap.remove(key); + } + + @Override + public boolean contains(K key) { + return backingMap.containsKey(key); + } + + @Override + public Iterator> iterator() { + return backingMap.entrySet().iterator(); + } + + @Override + public Iterable> entries() { + return backingMap.entrySet(); + } + + @Override + public Iterable> immutableEntries() { + return Collections.unmodifiableSet(backingMap.entrySet()); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java index 074d84e6b4c2c..e73f83a9e1d96 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java @@ -33,9 +33,10 @@ */ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritable { - public static final int VERSION = 2; + public static final int VERSION = 3; - private List> stateMetaInfoSnapshots; + private List> operatorStateMetaInfoSnapshots; + private List> broadcastStateMetaInfoSnapshots; private ClassLoader userCodeClassLoader; public OperatorBackendSerializationProxy(ClassLoader userCodeClassLoader) { @@ -43,10 +44,15 @@ public OperatorBackendSerializationProxy(ClassLoader userCodeClassLoader) { } public OperatorBackendSerializationProxy( - List> stateMetaInfoSnapshots) { - - this.stateMetaInfoSnapshots = Preconditions.checkNotNull(stateMetaInfoSnapshots); - Preconditions.checkArgument(stateMetaInfoSnapshots.size() <= Short.MAX_VALUE); + List> operatorStateMetaInfoSnapshots, + List> broadcastStateMetaInfoSnapshots) { + + this.operatorStateMetaInfoSnapshots = Preconditions.checkNotNull(operatorStateMetaInfoSnapshots); + this.broadcastStateMetaInfoSnapshots = 
Preconditions.checkNotNull(broadcastStateMetaInfoSnapshots); + Preconditions.checkArgument( + operatorStateMetaInfoSnapshots.size() <= Short.MAX_VALUE && + broadcastStateMetaInfoSnapshots.size() <= Short.MAX_VALUE + ); } @Override @@ -56,19 +62,26 @@ public int getVersion() { @Override public int[] getCompatibleVersions() { - // we are compatible with version 2 (Flink 1.3.x) and version 1 (Flink 1.2.x) - return new int[] {VERSION, 1}; + // we are compatible with version 3 (Flink 1.5.x), 2 (Flink 1.4.x, Flink 1.3.x) and version 1 (Flink 1.2.x) + return new int[] {VERSION, 2, 1}; } @Override public void write(DataOutputView out) throws IOException { super.write(out); - out.writeShort(stateMetaInfoSnapshots.size()); - for (RegisteredOperatorBackendStateMetaInfo.Snapshot kvState : stateMetaInfoSnapshots) { + out.writeShort(operatorStateMetaInfoSnapshots.size()); + for (RegisteredOperatorBackendStateMetaInfo.Snapshot state : operatorStateMetaInfoSnapshots) { + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getOperatorStateWriterForVersion(VERSION, state) + .writeOperatorStateMetaInfo(out); + } + + out.writeShort(broadcastStateMetaInfoSnapshots.size()); + for (RegisteredBroadcastBackendStateMetaInfo.Snapshot state : broadcastStateMetaInfoSnapshots) { OperatorBackendStateMetaInfoSnapshotReaderWriters - .getWriterForVersion(VERSION, kvState) - .writeStateMetaInfo(out); + .getBroadcastStateWriterForVersion(VERSION, state) + .writeBroadcastStateMetaInfo(out); } } @@ -76,17 +89,35 @@ public void write(DataOutputView out) throws IOException { public void read(DataInputView in) throws IOException { super.read(in); - int numKvStates = in.readShort(); - stateMetaInfoSnapshots = new ArrayList<>(numKvStates); - for (int i = 0; i < numKvStates; i++) { - stateMetaInfoSnapshots.add( - OperatorBackendStateMetaInfoSnapshotReaderWriters - .getReaderForVersion(getReadVersion(), userCodeClassLoader) - .readStateMetaInfo(in)); + int numOperatorStates = in.readShort(); + 
operatorStateMetaInfoSnapshots = new ArrayList<>(numOperatorStates); + for (int i = 0; i < numOperatorStates; i++) { + operatorStateMetaInfoSnapshots.add( + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getOperatorStateReaderForVersion(getReadVersion(), userCodeClassLoader) + .readOperatorStateMetaInfo(in)); } + + if (getReadVersion() >= 3) { + // broadcast states did not exist prior to version 3 + int numBroadcastStates = in.readShort(); + broadcastStateMetaInfoSnapshots = new ArrayList<>(numBroadcastStates); + for (int i = 0; i < numBroadcastStates; i++) { + broadcastStateMetaInfoSnapshots.add( + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateReaderForVersion(getReadVersion(), userCodeClassLoader) + .readBroadcastStateMetaInfo(in)); + } + } else { + broadcastStateMetaInfoSnapshots = new ArrayList<>(); + } + } + + public List> getOperatorStateMetaInfoSnapshots() { + return operatorStateMetaInfoSnapshots; } - public List> getStateMetaInfoSnapshots() { - return stateMetaInfoSnapshots; + public List> getBroadcastStateMetaInfoSnapshots() { + return broadcastStateMetaInfoSnapshots; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendStateMetaInfoSnapshotReaderWriters.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendStateMetaInfoSnapshotReaderWriters.java index 03fe612b016f9..fafd5423e01c3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendStateMetaInfoSnapshotReaderWriters.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendStateMetaInfoSnapshotReaderWriters.java @@ -31,7 +31,9 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; +import java.util.List; /** * Readers and writers for different versions of the {@link RegisteredOperatorBackendStateMetaInfo.Snapshot}. 
@@ -45,9 +47,10 @@ public class OperatorBackendStateMetaInfoSnapshotReaderWriters { // Writers // - v1: Flink 1.2.x // - v2: Flink 1.3.x + // - v3: Flink 1.5.x // ------------------------------------------------------------------------------- - public static OperatorBackendStateMetaInfoWriter getWriterForVersion( + public static OperatorBackendStateMetaInfoWriter getOperatorStateWriterForVersion( int version, RegisteredOperatorBackendStateMetaInfo.Snapshot stateMetaInfo) { switch (version) { @@ -55,6 +58,7 @@ public static OperatorBackendStateMetaInfoWriter getWriterForVersion( return new OperatorBackendStateMetaInfoWriterV1<>(stateMetaInfo); // current version + case 2: case OperatorBackendSerializationProxy.VERSION: return new OperatorBackendStateMetaInfoWriterV2<>(stateMetaInfo); @@ -65,8 +69,28 @@ public static OperatorBackendStateMetaInfoWriter getWriterForVersion( } } + public static BroadcastStateMetaInfoWriter getBroadcastStateWriterForVersion( + final int version, + final RegisteredBroadcastBackendStateMetaInfo.Snapshot broadcastStateMetaInfo) { + + switch (version) { + // current version + case OperatorBackendSerializationProxy.VERSION: + return new BroadcastStateMetaInfoWriterV3<>(broadcastStateMetaInfo); + + default: + // guard for future + throw new IllegalStateException( + "Unrecognized broadcast state meta info writer version: " + version); + } + } + public interface OperatorBackendStateMetaInfoWriter { - void writeStateMetaInfo(DataOutputView out) throws IOException; + void writeOperatorStateMetaInfo(DataOutputView out) throws IOException; + } + + public interface BroadcastStateMetaInfoWriter { + void writeBroadcastStateMetaInfo(final DataOutputView out) throws IOException; } public static abstract class AbstractOperatorBackendStateMetaInfoWriter @@ -79,6 +103,16 @@ public AbstractOperatorBackendStateMetaInfoWriter(RegisteredOperatorBackendState } } + public abstract static class AbstractBroadcastStateMetaInfoWriter + implements 
BroadcastStateMetaInfoWriter { + + protected final RegisteredBroadcastBackendStateMetaInfo.Snapshot broadcastStateMetaInfo; + + public AbstractBroadcastStateMetaInfoWriter(final RegisteredBroadcastBackendStateMetaInfo.Snapshot broadcastStateMetaInfo) { + this.broadcastStateMetaInfo = Preconditions.checkNotNull(broadcastStateMetaInfo); + } + } + public static class OperatorBackendStateMetaInfoWriterV1 extends AbstractOperatorBackendStateMetaInfoWriter { public OperatorBackendStateMetaInfoWriterV1(RegisteredOperatorBackendStateMetaInfo.Snapshot stateMetaInfo) { @@ -86,7 +120,7 @@ public OperatorBackendStateMetaInfoWriterV1(RegisteredOperatorBackendStateMetaIn } @Override - public void writeStateMetaInfo(DataOutputView out) throws IOException { + public void writeOperatorStateMetaInfo(DataOutputView out) throws IOException { out.writeUTF(stateMetaInfo.getName()); out.writeByte(stateMetaInfo.getAssignmentMode().ordinal()); TypeSerializerSerializationUtil.writeSerializer(out, stateMetaInfo.getPartitionStateSerializer()); @@ -100,7 +134,7 @@ public OperatorBackendStateMetaInfoWriterV2(RegisteredOperatorBackendStateMetaIn } @Override - public void writeStateMetaInfo(DataOutputView out) throws IOException { + public void writeOperatorStateMetaInfo(DataOutputView out) throws IOException { out.writeUTF(stateMetaInfo.getName()); out.writeByte(stateMetaInfo.getAssignmentMode().ordinal()); @@ -113,20 +147,51 @@ public void writeStateMetaInfo(DataOutputView out) throws IOException { } } + public static class BroadcastStateMetaInfoWriterV3 extends AbstractBroadcastStateMetaInfoWriter { + + public BroadcastStateMetaInfoWriterV3( + final RegisteredBroadcastBackendStateMetaInfo.Snapshot broadcastStateMetaInfo) { + super(broadcastStateMetaInfo); + } + + @Override + public void writeBroadcastStateMetaInfo(final DataOutputView out) throws IOException { + out.writeUTF(broadcastStateMetaInfo.getName()); + out.writeByte(broadcastStateMetaInfo.getAssignmentMode().ordinal()); + + // write 
in a way that allows us to be fault-tolerant and skip blocks in the case of java serialization failures + TypeSerializerSerializationUtil.writeSerializersAndConfigsWithResilience( + out, + Arrays.asList( + Tuple2.of( + broadcastStateMetaInfo.getKeySerializer(), + broadcastStateMetaInfo.getKeySerializerConfigSnapshot() + ), + Tuple2.of( + broadcastStateMetaInfo.getValueSerializer(), + broadcastStateMetaInfo.getValueSerializerConfigSnapshot() + ) + ) + ); + } + } + // ------------------------------------------------------------------------------- // Readers // - v1: Flink 1.2.x // - v2: Flink 1.3.x + // - v3: Flink 1.5.x // ------------------------------------------------------------------------------- - public static OperatorBackendStateMetaInfoReader getReaderForVersion( + public static OperatorBackendStateMetaInfoReader getOperatorStateReaderForVersion( int version, ClassLoader userCodeClassLoader) { switch (version) { case 1: return new OperatorBackendStateMetaInfoReaderV1<>(userCodeClassLoader); - // current version + // version 2 and version 3 (current) + case 2: case OperatorBackendSerializationProxy.VERSION: return new OperatorBackendStateMetaInfoReaderV2<>(userCodeClassLoader); @@ -137,8 +202,27 @@ public static OperatorBackendStateMetaInfoReader getReaderForVersion( } } + public static BroadcastStateMetaInfoReader getBroadcastStateReaderForVersion( + int version, ClassLoader userCodeClassLoader) { + + switch (version) { + // current version + case OperatorBackendSerializationProxy.VERSION: + return new BroadcastStateMetaInfoReaderV3<>(userCodeClassLoader); + + default: + // guard for future + throw new IllegalStateException( + "Unrecognized broadcast state meta info reader version: " + version); + } + } + public interface OperatorBackendStateMetaInfoReader { - RegisteredOperatorBackendStateMetaInfo.Snapshot readStateMetaInfo(DataInputView in) throws IOException; + RegisteredOperatorBackendStateMetaInfo.Snapshot readOperatorStateMetaInfo(DataInputView in) 
throws IOException; + } + + public interface BroadcastStateMetaInfoReader { + RegisteredBroadcastBackendStateMetaInfo.Snapshot readBroadcastStateMetaInfo(final DataInputView in) throws IOException; } public static abstract class AbstractOperatorBackendStateMetaInfoReader @@ -151,6 +235,16 @@ public AbstractOperatorBackendStateMetaInfoReader(ClassLoader userCodeClassLoade } } + public abstract static class AbstractBroadcastStateMetaInfoReader + implements BroadcastStateMetaInfoReader { + + protected final ClassLoader userCodeClassLoader; + + public AbstractBroadcastStateMetaInfoReader(final ClassLoader userCodeClassLoader) { + this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader); + } + } + public static class OperatorBackendStateMetaInfoReaderV1 extends AbstractOperatorBackendStateMetaInfoReader { public OperatorBackendStateMetaInfoReaderV1(ClassLoader userCodeClassLoader) { @@ -159,7 +253,7 @@ public OperatorBackendStateMetaInfoReaderV1(ClassLoader userCodeClassLoader) { @SuppressWarnings("unchecked") @Override - public RegisteredOperatorBackendStateMetaInfo.Snapshot readStateMetaInfo(DataInputView in) throws IOException { + public RegisteredOperatorBackendStateMetaInfo.Snapshot readOperatorStateMetaInfo(DataInputView in) throws IOException { RegisteredOperatorBackendStateMetaInfo.Snapshot stateMetaInfo = new RegisteredOperatorBackendStateMetaInfo.Snapshot<>(); @@ -196,7 +290,7 @@ public OperatorBackendStateMetaInfoReaderV2(ClassLoader userCodeClassLoader) { } @Override - public RegisteredOperatorBackendStateMetaInfo.Snapshot readStateMetaInfo(DataInputView in) throws IOException { + public RegisteredOperatorBackendStateMetaInfo.Snapshot readOperatorStateMetaInfo(DataInputView in) throws IOException { RegisteredOperatorBackendStateMetaInfo.Snapshot stateMetaInfo = new RegisteredOperatorBackendStateMetaInfo.Snapshot<>(); @@ -212,4 +306,34 @@ public RegisteredOperatorBackendStateMetaInfo.Snapshot readStateMetaInfo(Data return stateMetaInfo; } } 
+ + public static class BroadcastStateMetaInfoReaderV3 extends AbstractBroadcastStateMetaInfoReader { + + public BroadcastStateMetaInfoReaderV3(final ClassLoader userCodeClassLoader) { + super(userCodeClassLoader); + } + + @Override + public RegisteredBroadcastBackendStateMetaInfo.Snapshot readBroadcastStateMetaInfo(final DataInputView in) throws IOException { + RegisteredBroadcastBackendStateMetaInfo.Snapshot stateMetaInfo = + new RegisteredBroadcastBackendStateMetaInfo.Snapshot<>(); + + stateMetaInfo.setName(in.readUTF()); + stateMetaInfo.setAssignmentMode(OperatorStateHandle.Mode.values()[in.readByte()]); + + List, TypeSerializerConfigSnapshot>> serializers = + TypeSerializerSerializationUtil.readSerializersAndConfigsWithResilience(in, userCodeClassLoader); + + Tuple2, TypeSerializerConfigSnapshot> keySerializerAndConfig = serializers.get(0); + Tuple2, TypeSerializerConfigSnapshot> valueSerializerAndConfig = serializers.get(1); + + stateMetaInfo.setKeySerializer((TypeSerializer) keySerializerAndConfig.f0); + stateMetaInfo.setKeySerializerConfigSnapshot(keySerializerAndConfig.f1); + + stateMetaInfo.setValueSerializer((TypeSerializer) valueSerializerAndConfig.f0); + stateMetaInfo.setValueSerializerConfigSnapshot(valueSerializerAndConfig.f1); + + return stateMetaInfo; + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java index a357dc4c03699..f9427ef868491 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java @@ -36,8 +36,9 @@ public class OperatorStateHandle implements StreamStateHandle { * The modes that determine how an {@link OperatorStateHandle} is assigned to tasks during restore. 
*/ public enum Mode { - SPLIT_DISTRIBUTE, // The operator state partitions in the state handle are split and distributed to one task each. - BROADCAST // The operator state partitions are broadcast to all task. + SPLIT_DISTRIBUTE, // The operator state partitions in the state handle are split and distributed to one task each. + UNION, // The operator state partitions are UNION-ed upon restoring and sent to all tasks. + BROADCAST // The operator states are identical, as the state is produced from a broadcast stream. } private static final long serialVersionUID = 35876522969227335L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredBroadcastBackendStateMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredBroadcastBackendStateMetaInfo.java new file mode 100644 index 0000000000000..d462b34fab2cf --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredBroadcastBackendStateMetaInfo.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; +import org.apache.flink.util.Preconditions; + +import java.util.Objects; + +public class RegisteredBroadcastBackendStateMetaInfo { + + /** The name of the state, as registered by the user. */ + private final String name; + + /** The mode how elements in this state are assigned to tasks during restore. */ + private final OperatorStateHandle.Mode assignmentMode; + + /** The type serializer for the keys in the map state. */ + private final TypeSerializer keySerializer; + + /** The type serializer for the values in the map state. */ + private final TypeSerializer valueSerializer; + + public RegisteredBroadcastBackendStateMetaInfo( + final String name, + final OperatorStateHandle.Mode assignmentMode, + final TypeSerializer keySerializer, + final TypeSerializer valueSerializer) { + + Preconditions.checkArgument(assignmentMode != null && assignmentMode == OperatorStateHandle.Mode.BROADCAST); + + this.name = Preconditions.checkNotNull(name); + this.assignmentMode = assignmentMode; + this.keySerializer = Preconditions.checkNotNull(keySerializer); + this.valueSerializer = Preconditions.checkNotNull(valueSerializer); + } + + public String getName() { + return name; + } + + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + public OperatorStateHandle.Mode getAssignmentMode() { + return assignmentMode; + } + + public RegisteredBroadcastBackendStateMetaInfo.Snapshot snapshot() { + return new RegisteredBroadcastBackendStateMetaInfo.Snapshot<>( + name, + assignmentMode, + keySerializer.duplicate(), + valueSerializer.duplicate(), + keySerializer.snapshotConfiguration(), + valueSerializer.snapshotConfiguration()); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + 
return true; + } + + if (!(obj instanceof RegisteredBroadcastBackendStateMetaInfo)) { + return false; + } + + final RegisteredBroadcastBackendStateMetaInfo other = + (RegisteredBroadcastBackendStateMetaInfo) obj; + + return Objects.equals(name, other.getName()) + && Objects.equals(assignmentMode, other.getAssignmentMode()) + && Objects.equals(keySerializer, other.getKeySerializer()) + && Objects.equals(valueSerializer, other.getValueSerializer()); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + assignmentMode.hashCode(); + result = 31 * result + keySerializer.hashCode(); + result = 31 * result + valueSerializer.hashCode(); + return result; + } + + @Override + public String toString() { + return "RegisteredBroadcastBackendStateMetaInfo{" + + "name='" + name + '\'' + + ", keySerializer=" + keySerializer + + ", valueSerializer=" + valueSerializer + + ", assignmentMode=" + assignmentMode + + '}'; + } + + /** + * A consistent snapshot of a {@link RegisteredBroadcastBackendStateMetaInfo}. + */ + public static class Snapshot { + + private String name; + private OperatorStateHandle.Mode assignmentMode; + private TypeSerializer keySerializer; + private TypeSerializer valueSerializer; + private TypeSerializerConfigSnapshot keySerializerConfigSnapshot; + private TypeSerializerConfigSnapshot valueSerializerConfigSnapshot; + + /** Empty constructor used when restoring the state meta info snapshot.
*/ + Snapshot() {} + + private Snapshot( + final String name, + final OperatorStateHandle.Mode assignmentMode, + final TypeSerializer keySerializer, + final TypeSerializer valueSerializer, + final TypeSerializerConfigSnapshot keySerializerConfigSnapshot, + final TypeSerializerConfigSnapshot valueSerializerConfigSnapshot) { + + this.name = Preconditions.checkNotNull(name); + this.assignmentMode = Preconditions.checkNotNull(assignmentMode); + this.keySerializer = Preconditions.checkNotNull(keySerializer); + this.valueSerializer = Preconditions.checkNotNull(valueSerializer); + this.keySerializerConfigSnapshot = Preconditions.checkNotNull(keySerializerConfigSnapshot); + this.valueSerializerConfigSnapshot = Preconditions.checkNotNull(valueSerializerConfigSnapshot); + } + + public String getName() { + return name; + } + + void setName(String name) { + this.name = name; + } + + public OperatorStateHandle.Mode getAssignmentMode() { + return assignmentMode; + } + + void setAssignmentMode(OperatorStateHandle.Mode mode) { + this.assignmentMode = mode; + } + + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + void setKeySerializer(TypeSerializer serializer) { + this.keySerializer = serializer; + } + + public TypeSerializer getValueSerializer() { + return valueSerializer; + } + + void setValueSerializer(TypeSerializer serializer) { + this.valueSerializer = serializer; + } + + public TypeSerializerConfigSnapshot getKeySerializerConfigSnapshot() { + return keySerializerConfigSnapshot; + } + + void setKeySerializerConfigSnapshot(TypeSerializerConfigSnapshot configSnapshot) { + this.keySerializerConfigSnapshot = configSnapshot; + } + + public TypeSerializerConfigSnapshot getValueSerializerConfigSnapshot() { + return valueSerializerConfigSnapshot; + } + + void setValueSerializerConfigSnapshot(TypeSerializerConfigSnapshot configSnapshot) { + this.valueSerializerConfigSnapshot = configSnapshot; + } + + @Override + public boolean equals(Object obj) { + if (obj 
== this) { + return true; + } + + if (!(obj instanceof RegisteredBroadcastBackendStateMetaInfo.Snapshot)) { + return false; + } + + RegisteredBroadcastBackendStateMetaInfo.Snapshot snapshot = + (RegisteredBroadcastBackendStateMetaInfo.Snapshot) obj; + + return name.equals(snapshot.getName()) + && assignmentMode.ordinal() == snapshot.getAssignmentMode().ordinal() + && Objects.equals(keySerializer, snapshot.getKeySerializer()) + && Objects.equals(valueSerializer, snapshot.getValueSerializer()) + && keySerializerConfigSnapshot.equals(snapshot.getKeySerializerConfigSnapshot()) + && valueSerializerConfigSnapshot.equals(snapshot.getValueSerializerConfigSnapshot()); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + assignmentMode.hashCode(); + result = 31 * result + ((keySerializer != null) ? keySerializer.hashCode() : 0); + result = 31 * result + ((valueSerializer != null) ? valueSerializer.hashCode() : 0); + result = 31 * result + keySerializerConfigSnapshot.hashCode(); + result = 31 * result + valueSerializerConfigSnapshot.hashCode(); + return result; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index d3185751e1012..c791fd8d02e88 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -2731,10 +2731,15 @@ public void testExternalizedCheckpoints() throws Exception { @Test public void testReplicateModeStateHandle() { Map metaInfoMap = new HashMap<>(1); - metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST)); - metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST)); + 
metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.UNION)); + metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.UNION)); metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100])); + metaInfoMap.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{87, 94, 95}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{97, 108, 112}, OperatorStateHandle.Mode.BROADCAST)); + metaInfoMap.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{121, 143, 147}, OperatorStateHandle.Mode.BROADCAST)); + + // this is what a single task will return + OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[150])); OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; List> repartitionedStates = @@ -2757,18 +2762,26 @@ public void testReplicateModeStateHandle() { OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue(); if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) { + // SPLIT_DISTRIBUTE: so split the state and re-distribute it -> each one will go to one task Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length); - } else { + } else if (OperatorStateHandle.Mode.UNION.equals(stateMetaInfo.getDistributionMode())) { + // UNION: so all to all Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length); + } else { + // BROADCAST: so all to all + Assert.assertEquals(3, stateNameToMetaInfo.getValue().getOffsets().length); } } } } - Assert.assertEquals(3, checkCounts.size()); + Assert.assertEquals(6, checkCounts.size());
Assert.assertEquals(3, checkCounts.get("t-1").intValue()); Assert.assertEquals(3, checkCounts.get("t-2").intValue()); Assert.assertEquals(2, checkCounts.get("t-3").intValue()); + Assert.assertEquals(3, checkCounts.get("t-4").intValue()); + Assert.assertEquals(3, checkCounts.get("t-5").intValue()); + Assert.assertEquals(3, checkCounts.get("t-6").intValue()); } // ------------------------------------------------------------------------ @@ -3243,7 +3256,7 @@ private void doTestPartitionableStateRepartitioning( Path fakePath = new Path("/fake-" + i); Map namedStatesToOffsets = new HashMap<>(); int off = 0; - for (int s = 0; s < numNamedStates; ++s) { + for (int s = 0; s < numNamedStates - 1; ++s) { long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)]; for (int o = 0; o < offs.length; ++o) { @@ -3252,19 +3265,29 @@ private void doTestPartitionableStateRepartitioning( } OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ? - OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE; + OperatorStateHandle.Mode.UNION : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE; namedStatesToOffsets.put( "State-" + s, new OperatorStateHandle.StateMetaInfo(offs, mode)); } + if (numNamedStates % 2 == 0) { + // finally add a broadcast state + long[] offs = {off + 1, off + 2, off + 3, off + 4}; + + namedStatesToOffsets.put( + "State-" + (numNamedStates - 1), + new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.BROADCAST)); + } + previousParallelOpInstanceStates.add( new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1))); } Map>> expected = new HashMap<>(); + int taskIndex = 0; int expectedTotalPartitions = 0; for (OperatorStateHandle psh : previousParallelOpInstanceStates) { Map offsMap = psh.getStateNameToPartitionOffsets(); @@ -3272,20 +3295,39 @@ private void doTestPartitionableStateRepartitioning( for (Map.Entry e : offsMap.entrySet()) { long[] offs = e.getValue().getOffsets(); - int replication = 
e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ? - newParallelism : 1; + int replication; + switch (e.getValue().getDistributionMode()) { + case UNION: + replication = newParallelism; + break; + case BROADCAST: + int extra = taskIndex < (newParallelism % oldParallelism) ? 1 : 0; + replication = newParallelism / oldParallelism + extra; + break; + case SPLIT_DISTRIBUTE: + replication = 1; + break; + default: + throw new RuntimeException("Unknown distribution mode " + e.getValue().getDistributionMode()); + } - expectedTotalPartitions += replication * offs.length; - List offsList = new ArrayList<>(offs.length); + if (replication > 0) { + expectedTotalPartitions += replication * offs.length; + List offsList = new ArrayList<>(offs.length); - for (long off : offs) { - for (int p = 0; p < replication; ++p) { - offsList.add(off); + for (long off : offs) { + for (int p = 0; p < replication; ++p) { + offsList.add(off); + } } + offsMapWithList.put(e.getKey(), offsList); } - offsMapWithList.put(e.getKey(), offsList); } - expected.put(psh.getDelegateStateHandle(), offsMapWithList); + + if (!offsMapWithList.isEmpty()) { + expected.put(psh.getDelegateStateHandle(), offsMapWithList); + } + taskIndex++; } OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index acedb5071b252..d1d67ff940268 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -97,7 +97,7 @@ public static Collection createOperatorStates( Map offsetsMap = new HashMap<>(); offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.UNION)); if (hasOperatorStateBackend) { operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend); @@ -179,7 +179,7 @@ public static Collection createTaskStates( Map offsetsMap = new HashMap<>(); offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.UNION)); if (chainIdx != noOperatorStateBackendAtIndex) { OperatorStateHandle operatorStateHandleBackend = diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java index ef390db479029..1881dad968074 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java @@ -18,8 +18,11 @@ package org.apache.flink.runtime.state; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import 
org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.CompatibilityResult; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; @@ -49,7 +52,9 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; +import java.util.Map; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -85,6 +90,7 @@ public void testCreateOnAbstractStateBackend() throws Exception { assertNotNull(operatorStateBackend); assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty()); + assertTrue(operatorStateBackend.getRegisteredBroadcastStateNames().isEmpty()); } @Test @@ -233,6 +239,20 @@ public void testCorrectClassLoaderUsedOnSnapshot() throws Exception { listState.add(42); + AtomicInteger keyCopyCounter = new AtomicInteger(0); + AtomicInteger valueCopyCounter = new AtomicInteger(0); + + TypeSerializer keySerializer = new VerifyingIntSerializer(env.getUserClassLoader(), keyCopyCounter); + TypeSerializer valueSerializer = new VerifyingIntSerializer(env.getUserClassLoader(), valueCopyCounter); + + MapStateDescriptor broadcastStateDesc = new MapStateDescriptor<>( + "test-broadcast", keySerializer, valueSerializer); + + BroadcastState broadcastState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + broadcastState.put(1, 2); + broadcastState.put(3, 4); + broadcastState.put(5, 6); + CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); RunnableFuture runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); @@ -240,6 +260,8 @@ public void testCorrectClassLoaderUsedOnSnapshot() throws Exception { // make sure that the copy method has been called assertTrue(copyCounter.get() > 0); + 
assertTrue(keyCopyCounter.get() > 0); + assertTrue(valueCopyCounter.get() > 0); } /** @@ -360,18 +382,103 @@ public void testSnapshotEmpty() throws Exception { assertNull(stateHandle); } + @Test + public void testSnapshotBroadcastStateWithEmptyOperatorState() throws Exception { + final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); + + final OperatorStateBackend operatorStateBackend = + abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "testOperator"); + + final MapStateDescriptor broadcastStateDesc = new MapStateDescriptor<>( + "test-broadcast", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); + + final Map expected = new HashMap<>(3); + expected.put(1, 2); + expected.put(3, 4); + expected.put(5, 6); + + final BroadcastState broadcastState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + broadcastState.putAll(expected); + + final CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); + OperatorStateHandle stateHandle = null; + + try { + RunnableFuture snapshot = + operatorStateBackend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + assertNotNull(stateHandle); + + final Map retrieved = new HashMap<>(); + + operatorStateBackend.restore(Collections.singleton(stateHandle)); + BroadcastState retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry e: retrievedState.entries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertEquals(expected, retrieved); + + // remove an element from both expected and stored state. 
+ broadcastState.remove(1); + expected.remove(1); + + snapshot = operatorStateBackend.snapshot(1L, 1L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + + retrieved.clear(); + operatorStateBackend.restore(Collections.singleton(stateHandle)); + retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry e: retrievedState.immutableEntries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertEquals(expected, retrieved); + + // remove all elements from both expected and stored state. + broadcastState.clear(); + expected.clear(); + + snapshot = operatorStateBackend.snapshot(2L, 2L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); + stateHandle = FutureUtil.runIfNotDoneAndGet(snapshot); + + retrieved.clear(); + operatorStateBackend.restore(Collections.singleton(stateHandle)); + retrievedState = operatorStateBackend.getBroadcastState(broadcastStateDesc); + for (Map.Entry e: retrievedState.immutableEntries()) { + retrieved.put(e.getKey(), e.getValue()); + } + assertTrue(expected.isEmpty()); + assertEquals(expected, retrieved); + } finally { + operatorStateBackend.close(); + operatorStateBackend.dispose(); + if (stateHandle != null) { + stateHandle.discardState(); + } + } + } + @Test public void testSnapshotRestoreSync() throws Exception { - AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096); + AbstractStateBackend abstractStateBackend = new MemoryStateBackend(2 * 4096); OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-op-name"); ListStateDescriptor stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>()); ListStateDescriptor stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>()); ListStateDescriptor stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>()); + + 
MapStateDescriptor broadcastStateDescriptor1 = new MapStateDescriptor<>("test4", new JavaSerializer<>(), new JavaSerializer<>()); + MapStateDescriptor broadcastStateDescriptor2 = new MapStateDescriptor<>("test5", new JavaSerializer<>(), new JavaSerializer<>()); + MapStateDescriptor broadcastStateDescriptor3 = new MapStateDescriptor<>("test6", new JavaSerializer<>(), new JavaSerializer<>()); + ListState listState1 = operatorStateBackend.getListState(stateDescriptor1); ListState listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + BroadcastState broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + BroadcastState broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + BroadcastState broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + listState1.add(42); listState1.add(4711); @@ -384,7 +491,12 @@ public void testSnapshotRestoreSync() throws Exception { listState3.add(19); listState3.add(20); - CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(4096); + broadcastState1.put(1, 2); + broadcastState1.put(2, 5); + + broadcastState2.put(2, 5); + + CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(2 * 4096); RunnableFuture runnableFuture = operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()); OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(runnableFuture); @@ -401,12 +513,18 @@ public void testSnapshotRestoreSync() throws Exception { operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = 
operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); Iterator it = listState1.get().iterator(); assertEquals(42, it.next()); @@ -426,6 +544,27 @@ public void testSnapshotRestoreSync() throws Exception { assertEquals(20, it.next()); assertFalse(it.hasNext()); + Iterator> bIt = broadcastState1.iterator(); + assertTrue(bIt.hasNext()); + Map.Entry entry = bIt.next(); + assertEquals(1, entry.getKey()); + assertEquals(2, entry.getValue()); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey()); + assertEquals(5, entry.getValue()); + assertFalse(bIt.hasNext()); + + bIt = broadcastState2.iterator(); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey()); + assertEquals(5, entry.getValue()); + assertFalse(bIt.hasNext()); + + bIt = broadcastState3.iterator(); + assertFalse(bIt.hasNext()); + operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { @@ -444,10 +583,22 @@ public void testSnapshotRestoreAsync() throws Exception { new ListStateDescriptor<>("test2", new JavaSerializer()); ListStateDescriptor stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer()); + + MapStateDescriptor broadcastStateDescriptor1 = + new MapStateDescriptor<>("test4", new JavaSerializer(), new JavaSerializer()); + MapStateDescriptor broadcastStateDescriptor2 = + new MapStateDescriptor<>("test5", new JavaSerializer(), new JavaSerializer()); + MapStateDescriptor broadcastStateDescriptor3 = + new MapStateDescriptor<>("test6", 
new JavaSerializer(), new JavaSerializer()); + ListState listState1 = operatorStateBackend.getListState(stateDescriptor1); ListState listState2 = operatorStateBackend.getListState(stateDescriptor2); ListState listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + BroadcastState broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + BroadcastState broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + BroadcastState broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); @@ -460,6 +611,11 @@ public void testSnapshotRestoreAsync() throws Exception { listState3.add(MutableType.of(19)); listState3.add(MutableType.of(20)); + broadcastState1.put(MutableType.of(1), MutableType.of(2)); + broadcastState1.put(MutableType.of(2), MutableType.of(5)); + + broadcastState2.put(MutableType.of(2), MutableType.of(5)); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); @@ -482,6 +638,8 @@ public void testSnapshotRestoreAsync() throws Exception { listState1.add(MutableType.of(77)); + broadcastState1.put(MutableType.of(32), MutableType.of(97)); + int n = 0; for (MutableType mutableType : listState2.get()) { @@ -493,6 +651,7 @@ public void testSnapshotRestoreAsync() throws Exception { } listState3.clear(); + broadcastState2.clear(); operatorStateBackend.getListState( new ListStateDescriptor<>("test4", new JavaSerializer())); @@ -514,12 +673,18 @@ public void testSnapshotRestoreAsync() throws Exception { operatorStateBackend.restore(Collections.singletonList(stateHandle)); assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); listState1 = operatorStateBackend.getListState(stateDescriptor1); listState2 = 
operatorStateBackend.getListState(stateDescriptor2); listState3 = operatorStateBackend.getUnionListState(stateDescriptor3); + broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2); + broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3); + assertEquals(3, operatorStateBackend.getRegisteredStateNames().size()); + assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size()); Iterator it = listState1.get().iterator(); assertEquals(42, it.next().value); @@ -539,6 +704,27 @@ public void testSnapshotRestoreAsync() throws Exception { assertEquals(20, it.next().value); assertFalse(it.hasNext()); + Iterator> bIt = broadcastState1.iterator(); + assertTrue(bIt.hasNext()); + Map.Entry entry = bIt.next(); + assertEquals(1, entry.getKey().value); + assertEquals(2, entry.getValue().value); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey().value); + assertEquals(5, entry.getValue().value); + assertFalse(bIt.hasNext()); + + bIt = broadcastState2.iterator(); + assertTrue(bIt.hasNext()); + entry = bIt.next(); + assertEquals(2, entry.getKey().value); + assertEquals(5, entry.getValue().value); + assertFalse(bIt.hasNext()); + + bIt = broadcastState3.iterator(); + assertFalse(bIt.hasNext()); + operatorStateBackend.close(); operatorStateBackend.dispose(); } finally { @@ -558,10 +744,16 @@ public void testSnapshotAsyncClose() throws Exception { ListState listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); - listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); + MapStateDescriptor broadcastStateDescriptor1 = + new MapStateDescriptor<>("test4", new JavaSerializer(), new JavaSerializer()); + + BroadcastState broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1); + broadcastState1.put(MutableType.of(1), MutableType.of(2)); + 
broadcastState1.put(MutableType.of(2), MutableType.of(5)); + BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024); OneShotLatch waiterLatch = new OneShotLatch(); @@ -602,7 +794,6 @@ public void testSnapshotAsyncCancel() throws Exception { ListState listState1 = operatorStateBackend.getOperatorState(stateDescriptor1); - listState1.add(MutableType.of(42)); listState1.add(MutableType.of(4711)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java index ab801b6896b7f..88f9cd7eb3ef7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java @@ -28,10 +28,11 @@ public void testFixedEnumOrder() { // Ensure the order / ordinal of all values of enum 'mode' are fixed, as this is used for serialization Assert.assertEquals(0, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal()); - Assert.assertEquals(1, OperatorStateHandle.Mode.BROADCAST.ordinal()); + Assert.assertEquals(1, OperatorStateHandle.Mode.UNION.ordinal()); + Assert.assertEquals(2, OperatorStateHandle.Mode.BROADCAST.ordinal()); // Ensure all enum values are registered and fixed forever by this test - Assert.assertEquals(2, OperatorStateHandle.Mode.values().length); + Assert.assertEquals(3, OperatorStateHandle.Mode.values().length); // Byte is used to encode enum value on serialization Assert.assertTrue(OperatorStateHandle.Mode.values().length <= Byte.MAX_VALUE); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java index 341d4feeced37..57e4aed36c20d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.base.DoubleSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.core.memory.DataInputViewStreamWrapper; @@ -205,6 +206,8 @@ public void testKeyedStateMetaInfoReadSerializerFailureResilience() throws Excep public void testOperatorBackendSerializationProxyRoundtrip() throws Exception { TypeSerializer stateSerializer = DoubleSerializer.INSTANCE; + TypeSerializer keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer valueSerializer = StringSerializer.INSTANCE; List> stateMetaInfoSnapshots = new ArrayList<>(); @@ -213,10 +216,17 @@ public void testOperatorBackendSerializationProxyRoundtrip() throws Exception { stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>( "b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE).snapshot()); stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>( - "c", stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot()); + "c", stateSerializer, OperatorStateHandle.Mode.UNION).snapshot()); + + List> broadcastStateMetaInfoSnapshots = new ArrayList<>(); + + broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>( + "d", OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot()); + broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>( + "e", OperatorStateHandle.Mode.BROADCAST, valueSerializer, keySerializer).snapshot()); OperatorBackendSerializationProxy serializationProxy = - new OperatorBackendSerializationProxy(stateMetaInfoSnapshots); + new 
OperatorBackendSerializationProxy(stateMetaInfoSnapshots, broadcastStateMetaInfoSnapshots); byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { @@ -231,7 +241,8 @@ public void testOperatorBackendSerializationProxyRoundtrip() throws Exception { serializationProxy.read(new DataInputViewStreamWrapper(in)); } - Assert.assertEquals(stateMetaInfoSnapshots, serializationProxy.getStateMetaInfoSnapshots()); + Assert.assertEquals(stateMetaInfoSnapshots, serializationProxy.getOperatorStateMetaInfoSnapshots()); + Assert.assertEquals(broadcastStateMetaInfoSnapshots, serializationProxy.getBroadcastStateMetaInfoSnapshots()); } @Test @@ -242,24 +253,58 @@ public void testOperatorStateMetaInfoSerialization() throws Exception { RegisteredOperatorBackendStateMetaInfo.Snapshot metaInfo = new RegisteredOperatorBackendStateMetaInfo<>( - name, stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot(); + name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot(); byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { OperatorBackendStateMetaInfoSnapshotReaderWriters - .getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) - .writeStateMetaInfo(new DataOutputViewStreamWrapper(out)); + .getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + .writeOperatorStateMetaInfo(new DataOutputViewStreamWrapper(out)); serialized = out.toByteArray(); } try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters - .getReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) - .readStateMetaInfo(new DataInputViewStreamWrapper(in)); + .getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readOperatorStateMetaInfo(new DataInputViewStreamWrapper(in)); 
+ } + + Assert.assertEquals(name, metaInfo.getName()); + Assert.assertEquals(OperatorStateHandle.Mode.UNION, metaInfo.getAssignmentMode()); + Assert.assertEquals(stateSerializer, metaInfo.getPartitionStateSerializer()); + } + + @Test + public void testBroadcastStateMetaInfoSerialization() throws Exception { + + String name = "test"; + TypeSerializer keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer valueSerializer = StringSerializer.INSTANCE; + + RegisteredBroadcastBackendStateMetaInfo.Snapshot metaInfo = + new RegisteredBroadcastBackendStateMetaInfo<>( + name, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot(); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + .writeBroadcastStateMetaInfo(new DataOutputViewStreamWrapper(out)); + + serialized = out.toByteArray(); + } + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readBroadcastStateMetaInfo(new DataInputViewStreamWrapper(in)); } Assert.assertEquals(name, metaInfo.getName()); + Assert.assertEquals(OperatorStateHandle.Mode.BROADCAST, metaInfo.getAssignmentMode()); + Assert.assertEquals(keySerializer, metaInfo.getKeySerializer()); + Assert.assertEquals(valueSerializer, metaInfo.getValueSerializer()); } @Test @@ -269,13 +314,13 @@ public void testOperatorStateMetaInfoReadSerializerFailureResilience() throws Ex RegisteredOperatorBackendStateMetaInfo.Snapshot metaInfo = new RegisteredOperatorBackendStateMetaInfo<>( - name, stateSerializer, OperatorStateHandle.Mode.BROADCAST).snapshot(); + name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot(); 
byte[] serialized; try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { OperatorBackendStateMetaInfoSnapshotReaderWriters - .getWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) - .writeStateMetaInfo(new DataOutputViewStreamWrapper(out)); + .getOperatorStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, metaInfo) + .writeOperatorStateMetaInfo(new DataOutputViewStreamWrapper(out)); serialized = out.toByteArray(); } @@ -288,8 +333,8 @@ public void testOperatorStateMetaInfoReadSerializerFailureResilience() throws Ex try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { metaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters - .getReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) - .readStateMetaInfo(new DataInputViewStreamWrapper(in)); + .getOperatorStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readOperatorStateMetaInfo(new DataInputViewStreamWrapper(in)); } Assert.assertEquals(name, metaInfo.getName()); @@ -297,6 +342,44 @@ public void testOperatorStateMetaInfoReadSerializerFailureResilience() throws Ex Assert.assertEquals(stateSerializer.snapshotConfiguration(), metaInfo.getPartitionStateSerializerConfigSnapshot()); } + @Test + public void testBroadcastStateMetaInfoReadSerializerFailureResilience() throws Exception { + String broadcastName = "broadcastTest"; + TypeSerializer keySerializer = DoubleSerializer.INSTANCE; + TypeSerializer valueSerializer = StringSerializer.INSTANCE; + + RegisteredBroadcastBackendStateMetaInfo.Snapshot broadcastMetaInfo = + new RegisteredBroadcastBackendStateMetaInfo<>( + broadcastName, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot(); + + byte[] serialized; + try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) { + OperatorBackendStateMetaInfoSnapshotReaderWriters + 
.getBroadcastStateWriterForVersion(OperatorBackendSerializationProxy.VERSION, broadcastMetaInfo) + .writeBroadcastStateMetaInfo(new DataOutputViewStreamWrapper(out)); + + serialized = out.toByteArray(); + } + + // mock failure when deserializing serializer + TypeSerializerSerializationUtil.TypeSerializerSerializationProxy mockProxy = + mock(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class); + doThrow(new IOException()).when(mockProxy).read(any(DataInputViewStreamWrapper.class)); + PowerMockito.whenNew(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy); + + try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) { + broadcastMetaInfo = OperatorBackendStateMetaInfoSnapshotReaderWriters + .getBroadcastStateReaderForVersion(OperatorBackendSerializationProxy.VERSION, Thread.currentThread().getContextClassLoader()) + .readBroadcastStateMetaInfo(new DataInputViewStreamWrapper(in)); + } + + Assert.assertEquals(broadcastName, broadcastMetaInfo.getName()); + Assert.assertEquals(null, broadcastMetaInfo.getKeySerializer()); + Assert.assertEquals(keySerializer.snapshotConfiguration(), broadcastMetaInfo.getKeySerializerConfigSnapshot()); + Assert.assertEquals(null, broadcastMetaInfo.getValueSerializer()); + Assert.assertEquals(valueSerializer.snapshotConfiguration(), broadcastMetaInfo.getValueSerializerConfigSnapshot()); + } + /** * This test fixes the order of elements in the enum which is important for serialization. Do not modify this test * except if you are entirely sure what you are doing. From d87da64392fb944415a5d68cffc388a29654ed43 Mon Sep 17 00:00:00 2001 From: kkloudas Date: Thu, 21 Dec 2017 14:38:54 +0100 Subject: [PATCH 3/6] [FLINK-3659] Expose broadcast state on DataStream API. 
--- .../datastream/BroadcastConnectedStream.java | 255 +++++++ .../api/datastream/BroadcastStream.java | 87 +++ .../streaming/api/datastream/DataStream.java | 41 ++ .../co/BaseBroadcastProcessFunction.java | 105 +++ .../co/BroadcastProcessFunction.java | 93 +++ .../co/KeyedBroadcastProcessFunction.java | 145 ++++ .../api/graph/StreamGraphGenerator.java | 2 +- .../co/CoBroadcastWithKeyedOperator.java | 324 +++++++++ .../co/CoBroadcastWithNonKeyedOperator.java | 228 ++++++ .../TwoInputTransformation.java | 4 +- .../flink/streaming/api/DataStreamTest.java | 186 +++++ .../co/CoBroadcastWithKeyedOperatorTest.java | 655 ++++++++++++++++++ .../CoBroadcastWithNonKeyedOperatorTest.java | 497 +++++++++++++ .../TwoInputStreamOperatorTestHarness.java | 2 +- 14 files changed, 2620 insertions(+), 4 deletions(-) create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BaseBroadcastProcessFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperator.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java new file mode 100644 index 0000000000000..aeb3bc27b0825 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.datastream; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.Utils; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.operators.co.CoBroadcastWithKeyedOperator; +import org.apache.flink.streaming.api.operators.co.CoBroadcastWithNonKeyedOperator; +import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import org.apache.flink.util.Preconditions; + +import java.util.Collections; + +import static java.util.Objects.requireNonNull; + +/** + * A BroadcastConnectedStream represents the result of connecting a keyed or non-keyed stream, + * with a {@link BroadcastStream} with {@link org.apache.flink.api.common.state.BroadcastState + * BroadcastState}. As in the case of {@link ConnectedStreams} these streams are useful for cases + * where operations on one stream directly affect the operations on the other stream, usually via + * shared state between the streams. + * + *

An example for the use of such connected streams would be to apply rules that change over time + * onto another, possibly keyed stream. The stream with the broadcast state has the rules, and will + * store them in the broadcast state, while the other stream will contain the elements to apply the + * rules to. By broadcasting the rules, these will be available in all parallel instances, and + * can be applied to all partitions of the other stream. + * + * @param The input type of the non-broadcast side. + * @param The input type of the broadcast side. + * @param The key type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. + * @param The value type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. + */ +@PublicEvolving +public class BroadcastConnectedStream { + + private final StreamExecutionEnvironment environment; + private final DataStream inputStream1; + private final BroadcastStream inputStream2; + private final MapStateDescriptor broadcastStateDescriptor; + + protected BroadcastConnectedStream( + final StreamExecutionEnvironment env, + final DataStream input1, + final BroadcastStream input2, + final MapStateDescriptor broadcastStateDescriptor) { + this.environment = requireNonNull(env); + this.inputStream1 = requireNonNull(input1); + this.inputStream2 = requireNonNull(input2); + this.broadcastStateDescriptor = requireNonNull(broadcastStateDescriptor); + } + + public StreamExecutionEnvironment getExecutionEnvironment() { + return environment; + } + + /** + * Returns the non-broadcast {@link DataStream}. + * + * @return The stream which, by convention, is not broadcasted. + */ + public DataStream getFirstInput() { + return inputStream1; + } + + /** + * Returns the {@link BroadcastStream}. + * + * @return The stream which, by convention, is the broadcast one. 
+ */ + public BroadcastStream getSecondInput() { + return inputStream2; + } + + /** + * Gets the type of the first input. + * + * @return The type of the first input + */ + public TypeInformation getType1() { + return inputStream1.getType(); + } + + /** + * Gets the type of the second input. + * + * @return The type of the second input + */ + public TypeInformation getType2() { + return inputStream2.getType(); + } + + /** + * Assumes as inputs a {@link BroadcastStream} and a {@link KeyedStream} and applies the given + * {@link KeyedBroadcastProcessFunction} on them, thereby creating a transformed output stream. + * + * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. + * @param The type of the output elements. + * @return The transformed {@link DataStream}. + */ + @PublicEvolving + public SingleOutputStreamOperator process(final KeyedBroadcastProcessFunction function) { + + TypeInformation outTypeInfo = TypeExtractor.getBinaryOperatorReturnType( + function, + KeyedBroadcastProcessFunction.class, + 0, + 1, + 2, + TypeExtractor.NO_INDEX, + TypeExtractor.NO_INDEX, + TypeExtractor.NO_INDEX, + getType1(), + getType2(), + Utils.getCallLocationName(), + true); + + return process(function, outTypeInfo); + } + + /** + * Assumes as inputs a {@link BroadcastStream} and a {@link KeyedStream} and applies the given + * {@link KeyedBroadcastProcessFunction} on them, thereby creating a transformed output stream. + * + * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. + * @param outTypeInfo The type of the output elements. + * @param The type of the output elements. + * @return The transformed {@link DataStream}. 
+ */ + @PublicEvolving + public SingleOutputStreamOperator process( + final KeyedBroadcastProcessFunction function, + final TypeInformation outTypeInfo) { + + Preconditions.checkNotNull(function); + Preconditions.checkArgument(inputStream1 instanceof KeyedStream, + "A KeyedBroadcastProcessFunction can only be used with a keyed stream as the second input."); + + TwoInputStreamOperator operator = + new CoBroadcastWithKeyedOperator<>(function, Collections.singletonList(broadcastStateDescriptor)); + return transform("Co-Process-Broadcast-Keyed", outTypeInfo, operator); + } + + /** + * Assumes as inputs a {@link BroadcastStream} and a non-keyed {@link DataStream} and applies the given + * {@link BroadcastProcessFunction} on them, thereby creating a transformed output stream. + * + * @param function The {@link BroadcastProcessFunction} that is called for each element in the stream. + * @param The type of the output elements. + * @return The transformed {@link DataStream}. + */ + @PublicEvolving + public SingleOutputStreamOperator process(final BroadcastProcessFunction function) { + + TypeInformation outTypeInfo = TypeExtractor.getBinaryOperatorReturnType( + function, + BroadcastProcessFunction.class, + 0, + 1, + 2, + TypeExtractor.NO_INDEX, + TypeExtractor.NO_INDEX, + TypeExtractor.NO_INDEX, + getType1(), + getType2(), + Utils.getCallLocationName(), + true); + + return process(function, outTypeInfo); + } + + /** + * Assumes as inputs a {@link BroadcastStream} and a non-keyed {@link DataStream} and applies the given + * {@link BroadcastProcessFunction} on them, thereby creating a transformed output stream. + * + * @param function The {@link BroadcastProcessFunction} that is called for each element in the stream. + * @param outTypeInfo The type of the output elements. + * @param The type of the output elements. + * @return The transformed {@link DataStream}. 
+ */ + @PublicEvolving + public SingleOutputStreamOperator process( + final BroadcastProcessFunction function, + final TypeInformation outTypeInfo) { + + Preconditions.checkNotNull(function); + Preconditions.checkArgument(!(inputStream1 instanceof KeyedStream), + "A BroadcastProcessFunction can only be used with a non-keyed stream as the second input."); + + TwoInputStreamOperator operator = + new CoBroadcastWithNonKeyedOperator<>(function, Collections.singletonList(broadcastStateDescriptor)); + return transform("Co-Process-Broadcast", outTypeInfo, operator); + } + + @Internal + private SingleOutputStreamOperator transform( + final String functionName, + final TypeInformation outTypeInfo, + final TwoInputStreamOperator operator) { + + // read the output type of the input Transforms to coax out errors about MissingTypeInfo + inputStream1.getType(); + inputStream2.getType(); + + TwoInputTransformation transform = new TwoInputTransformation<>( + inputStream1.getTransformation(), + inputStream2.getTransformation(), + functionName, + operator, + outTypeInfo, + environment.getParallelism()); + + if (inputStream1 instanceof KeyedStream) { + KeyedStream keyedInput1 = (KeyedStream) inputStream1; + TypeInformation keyType1 = keyedInput1.getKeyType(); + transform.setStateKeySelectors(keyedInput1.getKeySelector(), null); + transform.setStateKeyType(keyType1); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + SingleOutputStreamOperator returnStream = new SingleOutputStreamOperator(environment, transform); + + getExecutionEnvironment().addOperator(transform); + + return returnStream; + } + + protected F clean(F f) { + return getExecutionEnvironment().clean(f); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java new file mode 100644 index 0000000000000..e21e36faff22b --- /dev/null +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.datastream; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.transformations.StreamTransformation; + +import static java.util.Objects.requireNonNull; + +/** + * A {@code BroadcastStream} is a stream with {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. + * This can be created by any stream using the {@link DataStream#broadcast(MapStateDescriptor)} method and + * implicitly creates a state where the user can store elements of the created {@code BroadcastStream}. + * (see {@link BroadcastConnectedStream}). + * + *

Note that no further operation can be applied to these streams. The only available option is to connect them + * with a keyed or non-keyed stream, using the {@link KeyedStream#connect(BroadcastStream)} and the + * {@link DataStream#connect(BroadcastStream)} respectively. Applying these methods will result it a + * {@link BroadcastConnectedStream} for further processing. + * + * @param The type of input/output elements. + * @param The key type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. + * @param The value type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. + */ +@PublicEvolving +public class BroadcastStream { + + private final StreamExecutionEnvironment environment; + + private final DataStream inputStream; + + /** + * The {@link org.apache.flink.api.common.state.StateDescriptor state descriptor} of the + * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. This state + * has a {@code key-value} format. 
+ */ + private final MapStateDescriptor broadcastStateDescriptor; + + protected BroadcastStream( + final StreamExecutionEnvironment env, + final DataStream input, + final MapStateDescriptor broadcastStateDescriptor) { + + this.environment = requireNonNull(env); + this.inputStream = requireNonNull(input); + this.broadcastStateDescriptor = requireNonNull(broadcastStateDescriptor); + } + + public TypeInformation getType() { + return inputStream.getType(); + } + + public F clean(F f) { + return environment.clean(f); + } + + public StreamTransformation getTransformation() { + return inputStream.getTransformation(); + } + + public MapStateDescriptor getBroadcastStateDescriptor() { + return broadcastStateDescriptor; + } + + public StreamExecutionEnvironment getEnvironment() { + return environment; + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java index 83c11266ce00a..d85968957e057 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java @@ -32,6 +32,7 @@ import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.common.operators.ResourceSpec; import org.apache.flink.api.common.serialization.SerializationSchema; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -252,6 +253,30 @@ public ConnectedStreams connect(DataStream dataStream) { return new ConnectedStreams<>(environment, this, dataStream); } + /** + * Creates a new {@link BroadcastConnectedStream} by connecting the current + * {@link DataStream} or {@link KeyedStream} with a {@link 
BroadcastStream}. + * + *

The latter can be created using the {@link #broadcast(MapStateDescriptor)} method. + * + *

The resulting stream can be further processed using the {@code BroadcastConnectedStream.process(MyFunction)} + * method, where {@code MyFunction} can be either a + * {@link org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction KeyedBroadcastProcessFunction} + * or a {@link org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction BroadcastProcessFunction} + * depending on the current stream being a {@link KeyedStream} or not. + * + * @param broadcastStream The broadcast stream with the broadcast state to be connected with this stream. + * @return The {@link BroadcastConnectedStream}. + */ + @PublicEvolving + public BroadcastConnectedStream connect(BroadcastStream broadcastStream) { + return new BroadcastConnectedStream<>( + environment, + this, + Preconditions.checkNotNull(broadcastStream), + broadcastStream.getBroadcastStateDescriptor()); + } + /** * It creates a new {@link KeyedStream} that uses the provided key for partitioning * its operator states. @@ -371,6 +396,22 @@ public DataStream broadcast() { return setConnectionType(new BroadcastPartitioner()); } + /** + * Sets the partitioning of the {@link DataStream} so that the output elements + * are broadcasted to every parallel instance of the next operation. In addition, + * it implicitly creates a {@link org.apache.flink.api.common.state.BroadcastState broadcast state} + * which can be used to store the element of the stream. + * + * @return A {@link BroadcastStream} which can be used in the {@link #connect(BroadcastStream)} to + * create a {@link BroadcastConnectedStream} for further processing of the elements. 
+ */ + @PublicEvolving + public BroadcastStream broadcast(final MapStateDescriptor broadcastStateDescriptor) { + Preconditions.checkNotNull(broadcastStateDescriptor); + final DataStream broadcastStream = setConnectionType(new BroadcastPartitioner<>()); + return new BroadcastStream<>(environment, broadcastStream, broadcastStateDescriptor); + } + /** * Sets the partitioning of the {@link DataStream} so that the output elements * are shuffled uniformly randomly to the next operation. diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BaseBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BaseBroadcastProcessFunction.java new file mode 100644 index 0000000000000..9419d806603e1 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BaseBroadcastProcessFunction.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.functions.co; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.util.OutputTag; + +/** + * The base class containing the functionality available to all broadcast process function. + * These include the {@link BroadcastProcessFunction} and the {@link KeyedBroadcastProcessFunction}. + */ +@PublicEvolving +public abstract class BaseBroadcastProcessFunction extends AbstractRichFunction { + + private static final long serialVersionUID = -131631008887478610L; + + /** + * The base context available to all methods in a broadcast process function. This + * include {@link BroadcastProcessFunction BroadcastProcessFunctions} and + * {@link KeyedBroadcastProcessFunction KeyedBroadcastProcessFunctions}. + */ + abstract class BaseContext { + + /** + * Timestamp of the element currently being processed or timestamp of a firing timer. + * + *

This might be {@code null}, for example if the time characteristic of your program + * is set to {@link org.apache.flink.streaming.api.TimeCharacteristic#ProcessingTime}. + */ + public abstract Long timestamp(); + + /** + * Emits a record to the side output identified by the {@link OutputTag}. + * + * @param outputTag the {@code OutputTag} that identifies the side output to emit to. + * @param value The record to emit. + */ + public abstract void output(OutputTag outputTag, X value); + + /** Returns the current processing time. */ + public abstract long currentProcessingTime(); + + /** Returns the current event-time watermark. */ + public abstract long currentWatermark(); + } + + /** + * A base {@link BaseContext context} available to the broadcasted stream side of + * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream BroadcastConnectedStream}. + * + *

Apart from the basic functionality of a {@link BaseContext context}, + * this also allows to get and update the elements stored in the + * {@link BroadcastState broadcast state}. + * In other words, it gives read/write access to the broadcast state. + */ + public abstract class Context extends BaseContext { + + /** + * Fetches the {@link BroadcastState} with the specified name. + * + * @param stateDescriptor the {@link MapStateDescriptor} of the state to be fetched. + * @return The required {@link BroadcastState broadcast state}. + */ + public abstract BroadcastState getBroadcastState(MapStateDescriptor stateDescriptor); + } + + /** + * A {@link BaseContext context} available to the non-broadcasted stream side of + * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream BroadcastConnectedStream}. + * + *

Apart from the basic functionality of a {@link BaseContext context}, + * this also allows to get a read-only {@link Iterable} over the elements stored in the + * broadcast state. + */ + public abstract class ReadOnlyContext extends BaseContext { + + /** + * Fetches a read-only view of the broadcast state with the specified name. + * + * @param stateDescriptor the {@link MapStateDescriptor} of the state to be fetched. + * @return The required read-only view of the broadcast state. + */ + public abstract ReadOnlyBroadcastState getBroadcastState(MapStateDescriptor stateDescriptor); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java new file mode 100644 index 0000000000000..4dcc92992cd4c --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.functions.co; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.util.Collector; + +/** + * A function to be applied to a + * {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream BroadcastConnectedStream} that + * connects {@link org.apache.flink.streaming.api.datastream.BroadcastStream BroadcastStream}, i.e. a stream + * with broadcast state, with a non-keyed {@link org.apache.flink.streaming.api.datastream.DataStream DataStream}. + * + *

The stream with the broadcast state can be created using the + * {@link org.apache.flink.streaming.api.datastream.DataStream#broadcast(MapStateDescriptor) + * stream.broadcast(MapStateDescriptor)} method. + * + *

The user has to implement two methods: + *

    + *
  1. the {@link #processBroadcastElement(Object, Context, Collector)} which will be applied to + * each element in the broadcast side + *
  2. and the {@link #processElement(Object, ReadOnlyContext, Collector)} which will be applied to the + * non-broadcasted/keyed side. + *
+ * + *

The {@code processElementOnBroadcastSide()} takes as argument (among others) a context that allows it to + * read/write to the broadcast state, while the {@code processElement()} has read-only access to the broadcast state. + * + * @param The input type of the non-broadcast side. + * @param The input type of the broadcast side. + * @param The output type of the operator. + */ +@PublicEvolving +public abstract class BroadcastProcessFunction extends BaseBroadcastProcessFunction { + + private static final long serialVersionUID = 8352559162119034453L; + + /** + * This method is called for each element in the (non-broadcast) + * {@link org.apache.flink.streaming.api.datastream.DataStream data stream}. + * + *

This function can output zero or more elements using the {@link Collector} parameter, + * query the current processing/event time, and also query and update the local keyed state. + * Finally, it has read-only access to the broadcast state. + * The context is only valid during the invocation of this method, do not store it. + * + * @param value The stream element. + * @param ctx A {@link ReadOnlyContext} that allows querying the timestamp of the element, + * querying the current processing/event time and updating the broadcast state. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processElement(IN1 value, ReadOnlyContext ctx, Collector out) throws Exception; + + /** + * This method is called for each element in the + * {@link org.apache.flink.streaming.api.datastream.BroadcastStream broadcast stream}. + * + *

This function can output zero or more elements using the {@link Collector} parameter, + * query the current processing/event time, and also query and update the internal + * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. These can be done + * through the provided {@link Context}. + * The context is only valid during the invocation of this method, do not store it. + * + * @param value The stream element. + * @param ctx A {@link Context} that allows querying the timestamp of the element, + * querying the current processing/event time and updating the broadcast state. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processBroadcastElement(IN2 value, Context ctx, Collector out) throws Exception; +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java new file mode 100644 index 0000000000000..9d14259b9aed3 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions.co; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.util.Collector; + +/** + * A function to be applied to a + * {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream BroadcastConnectedStream} that + * connects {@link org.apache.flink.streaming.api.datastream.BroadcastStream BroadcastStream}, i.e. a stream + * with broadcast state, with a {@link org.apache.flink.streaming.api.datastream.KeyedStream KeyedStream}. + * + *

The stream with the broadcast state can be created using the + * {@link org.apache.flink.streaming.api.datastream.KeyedStream#broadcast(MapStateDescriptor) + * keyedStream.broadcast(MapStateDescriptor)} method. + * + *

The user has to implement two methods: + *

    + *
  1. the {@link #processBroadcastElement(Object, Context, Collector)} which will be applied to + * each element in the broadcast side + *
  2. and the {@link #processElement(Object, KeyedReadOnlyContext, Collector)} which will be applied to the + * non-broadcasted/keyed side. + *
+ * + *

The {@code processElementOnBroadcastSide()} takes as an argument (among others) a context that allows it to + * read/write to the broadcast state and also apply a transformation to all (local) keyed states, while the + * {@code processElement()} has read-only access to the broadcast state, but can read/write to the keyed state and + * register timers. + * + * @param The input type of the keyed (non-broadcast) side. + * @param The input type of the broadcast side. + * @param The output type of the operator. + */ +@PublicEvolving +public abstract class KeyedBroadcastProcessFunction extends BaseBroadcastProcessFunction { + + private static final long serialVersionUID = -2584726797564976453L; + + /** + * This method is called for each element in the (non-broadcast) + * {@link org.apache.flink.streaming.api.datastream.KeyedStream keyed stream}. + * + *

It can output zero or more elements using the {@link Collector} parameter, + * query the current processing/event time, and also query and update the local keyed state. + * In addition, it can get a {@link TimerService} for registering timers and querying the time. + * Finally, it has read-only access to the broadcast state. + * The context is only valid during the invocation of this method, do not store it. + * + * @param value The stream element. + * @param ctx A {@link KeyedReadOnlyContext} that allows querying the timestamp of the element, + * querying the current processing/event time and iterating the broadcast state + * with read-only access. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processElement(final IN1 value, final KeyedReadOnlyContext ctx, final Collector out) throws Exception; + + /** + * This method is called for each element in the + * {@link org.apache.flink.streaming.api.datastream.BroadcastStream broadcast stream}. + * + *

It can output zero or more elements using the {@link Collector} parameter, + * query the current processing/event time, and also query and update the internal + * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. These can + * be done through the provided {@link Context}. + * The context is only valid during the invocation of this method, do not store it. + * + * @param value The stream element. + * @param ctx A {@link Context} that allows querying the timestamp of the element, + * querying the current processing/event time and updating the broadcast state. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processBroadcastElement(final IN2 value, final Context ctx, final Collector out) throws Exception; + + /** + * Called when a timer set using {@link TimerService} fires. + * + * @param timestamp The timestamp of the firing timer. + * @param ctx An {@link OnTimerContext} that allows querying the timestamp of the firing timer, + * querying the current processing/event time, iterating the broadcast state + * with read-only access, querying the {@link TimeDomain} of the firing timer + * and getting a {@link TimerService} for registering timers and querying the time. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector for returning result values. + * + * @throws Exception This method may throw exceptions. Throwing an exception will cause the operation + * to fail and may trigger recovery. + */ + public void onTimer(final long timestamp, final OnTimerContext ctx, final Collector out) throws Exception { + // the default implementation does nothing. 
+ } + + /** + * A {@link BaseBroadcastProcessFunction.Context context} available to the keyed stream side of + * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream} (if any). + * + *

Apart from the basic functionality of a {@link BaseBroadcastProcessFunction.Context context}, + * this also allows to get a read-only {@link Iterable} over the elements stored in the + * broadcast state and a {@link TimerService} for querying time and registering timers. + */ + public abstract class KeyedReadOnlyContext extends ReadOnlyContext { + + /** + * A {@link TimerService} for querying time and registering timers. + */ + public abstract TimerService timerService(); + } + + /** + * Information available in an invocation of {@link #onTimer(long, OnTimerContext, Collector)}. + */ + public abstract class OnTimerContext extends KeyedReadOnlyContext { + + /** + * The {@link TimeDomain} of the firing timer, i.e. if it is + * event or processing time timer. + */ + public abstract TimeDomain timeDomain(); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java index 0a05f09eb5ac1..7d0333f4160ed 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java @@ -586,7 +586,7 @@ private Collection transformTwoInputTransform(TwoInputT transform.getOutputType(), transform.getName()); - if (transform.getStateKeySelector1() != null) { + if (transform.getStateKeySelector1() != null || transform.getStateKeySelector2() != null) { TypeSerializer keySerializer = transform.getStateKeyType().createSerializer(env.getConfig()); streamGraph.setTwoInputStateKey(transform.getId(), transform.getStateKeySelector1(), transform.getStateKeySelector2(), keySerializer); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java new file mode 100644 index 0000000000000..794b0db707dfa --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.SimpleTimerService; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.co.BaseBroadcastProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimer; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * A {@link TwoInputStreamOperator} for executing {@link KeyedBroadcastProcessFunction KeyedBroadcastProcessFunctions}. + * + * @param The key type of the input keyed stream. + * @param The input type of the keyed (non-broadcast) side. + * @param The input type of the broadcast side. + * @param The output type of the operator. 
+ */ +@Internal +public class CoBroadcastWithKeyedOperator + extends AbstractUdfStreamOperator> + implements TwoInputStreamOperator, Triggerable { + + private static final long serialVersionUID = 5926499536290284870L; + + private final List> broadcastStateDescriptors; + + private transient TimestampedCollector collector; + + private transient Map, BroadcastState> broadcastStates; + + private transient ReadWriteContextImpl rwContext; + + private transient ReadOnlyContextImpl rContext; + + private transient OnTimerContextImpl onTimerContext; + + public CoBroadcastWithKeyedOperator( + final KeyedBroadcastProcessFunction function, + final List> broadcastStateDescriptors) { + super(function); + this.broadcastStateDescriptors = Preconditions.checkNotNull(broadcastStateDescriptors); + } + + @Override + public void open() throws Exception { + super.open(); + + InternalTimerService internalTimerService = + getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this); + + TimerService timerService = new SimpleTimerService(internalTimerService); + + collector = new TimestampedCollector<>(output); + + this.broadcastStates = new HashMap<>(broadcastStateDescriptors.size()); + for (MapStateDescriptor descriptor: broadcastStateDescriptors) { + broadcastStates.put(descriptor, getOperatorStateBackend().getBroadcastState(descriptor)); + } + + rwContext = new ReadWriteContextImpl(userFunction, broadcastStates, timerService); + rContext = new ReadOnlyContextImpl(userFunction, broadcastStates, timerService); + onTimerContext = new OnTimerContextImpl(userFunction, broadcastStates, timerService); + } + + @Override + public void processElement1(StreamRecord element) throws Exception { + collector.setTimestamp(element); + rContext.setElement(element); + userFunction.processElement(element.getValue(), rContext, collector); + rContext.setElement(null); + } + + @Override + public void processElement2(StreamRecord element) throws Exception { + collector.setTimestamp(element); 
+ rwContext.setElement(element); + userFunction.processBroadcastElement(element.getValue(), rwContext, collector); + rwContext.setElement(null); + } + + @Override + public void onEventTime(InternalTimer timer) throws Exception { + collector.setAbsoluteTimestamp(timer.getTimestamp()); + onTimerContext.timeDomain = TimeDomain.EVENT_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + @Override + public void onProcessingTime(InternalTimer timer) throws Exception { + collector.eraseTimestamp(); + onTimerContext.timeDomain = TimeDomain.PROCESSING_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + private class ReadWriteContextImpl extends BaseBroadcastProcessFunction.Context { + + private final Map, BroadcastState> states; + + private final TimerService timerService; + + private StreamRecord element; + + ReadWriteContextImpl ( + final KeyedBroadcastProcessFunction function, + final Map, BroadcastState> broadcastStates, + final TimerService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.getTimestamp(); + } + + @Override + public BroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + BroadcastState state = (BroadcastState) states.get(stateDescriptor); + if (state == null) { + throw new IllegalArgumentException("The requested state does not exist. " + + "Check for typos in your state descriptor, or specify the state descriptor " + + "in the datastream.broadcast(...) 
call if you forgot to register it."); + } + return state; + } + + @Override + public void output(OutputTag outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public long currentProcessingTime() { + return timerService.currentProcessingTime(); + } + + @Override + public long currentWatermark() { + return timerService.currentWatermark(); + } + } + + private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction.KeyedReadOnlyContext { + + private final Map, BroadcastState> states; + + private final TimerService timerService; + + private StreamRecord element; + + ReadOnlyContextImpl( + final KeyedBroadcastProcessFunction function, + final Map, BroadcastState> broadcastStates, + final TimerService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.hasTimestamp() ? 
element.getTimestamp() : null; + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public long currentProcessingTime() { + return timerService.currentProcessingTime(); + } + + @Override + public long currentWatermark() { + return timerService.currentWatermark(); + } + + @Override + public void output(OutputTag outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public ReadOnlyBroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + ReadOnlyBroadcastState state = (ReadOnlyBroadcastState) states.get(stateDescriptor); + if (state == null) { + throw new IllegalArgumentException("The requested state does not exist. " + + "Check for typos in your state descriptor, or specify the state descriptor " + + "in the datastream.broadcast(...) call if you forgot to register it."); + } + return state; + } + } + + private class OnTimerContextImpl extends KeyedBroadcastProcessFunction.OnTimerContext { + + private final Map, BroadcastState> states; + + private final TimerService timerService; + + private TimeDomain timeDomain; + + private InternalTimer timer; + + OnTimerContextImpl( + final KeyedBroadcastProcessFunction function, + final Map, BroadcastState> broadcastStates, + final TimerService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + @Override + public Long timestamp() { + checkState(timer != null); + return timer.getTimestamp(); + } + + @Override + public TimeDomain timeDomain() { + checkState(timeDomain != null); + return timeDomain; + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public long currentProcessingTime() { + return 
timerService.currentProcessingTime(); + } + + @Override + public long currentWatermark() { + return timerService.currentWatermark(); + } + + @Override + public void output(OutputTag outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, timer.getTimestamp())); + } + + @Override + public ReadOnlyBroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + ReadOnlyBroadcastState state = (ReadOnlyBroadcastState) states.get(stateDescriptor); + if (state == null) { + throw new IllegalArgumentException("The requested state does not exist. " + + "Check for typos in your state descriptor, or specify the state descriptor " + + "in the datastream.broadcast(...) call if you forgot to register it."); + } + return state; + } + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperator.java new file mode 100644 index 0000000000000..25bf873419da9 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperator.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.streaming.api.functions.co.BaseBroadcastProcessFunction; +import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * A {@link TwoInputStreamOperator} for executing {@link BroadcastProcessFunction BroadcastProcessFunctions}. + * + * @param The input type of the keyed (non-broadcast) side. + * @param The input type of the broadcast side. + * @param The output type of the operator. 
+ */ +@Internal +public class CoBroadcastWithNonKeyedOperator + extends AbstractUdfStreamOperator> + implements TwoInputStreamOperator { + + private static final long serialVersionUID = -1869740381935471752L; + + /** We listen to this ourselves because we don't have an {@link InternalTimerService}. */ + private long currentWatermark = Long.MIN_VALUE; + + private final List> broadcastStateDescriptors; + + private transient TimestampedCollector collector; + + private transient Map, BroadcastState> broadcastStates; + + private transient ReadWriteContextImpl rwContext; + + private transient ReadOnlyContextImpl rContext; + + public CoBroadcastWithNonKeyedOperator( + final BroadcastProcessFunction function, + final List> broadcastStateDescriptors) { + super(function); + this.broadcastStateDescriptors = Preconditions.checkNotNull(broadcastStateDescriptors); + } + + @Override + public void open() throws Exception { + super.open(); + + collector = new TimestampedCollector<>(output); + + this.broadcastStates = new HashMap<>(broadcastStateDescriptors.size()); + for (MapStateDescriptor descriptor: broadcastStateDescriptors) { + broadcastStates.put(descriptor, getOperatorStateBackend().getBroadcastState(descriptor)); + } + + rwContext = new ReadWriteContextImpl(userFunction, broadcastStates, getProcessingTimeService()); + rContext = new ReadOnlyContextImpl(userFunction, broadcastStates, getProcessingTimeService()); + } + + @Override + public void processElement1(StreamRecord element) throws Exception { + collector.setTimestamp(element); + rContext.setElement(element); + userFunction.processElement(element.getValue(), rContext, collector); + rContext.setElement(null); + } + + @Override + public void processElement2(StreamRecord element) throws Exception { + collector.setTimestamp(element); + rwContext.setElement(element); + userFunction.processBroadcastElement(element.getValue(), rwContext, collector); + rwContext.setElement(null); + } + + @Override + public void 
processWatermark(Watermark mark) throws Exception { + super.processWatermark(mark); + currentWatermark = mark.getTimestamp(); + } + + private class ReadWriteContextImpl extends BaseBroadcastProcessFunction.Context { + + private final Map, BroadcastState> states; + + private final ProcessingTimeService timerService; + + private StreamRecord element; + + ReadWriteContextImpl( + final BroadcastProcessFunction function, + final Map, BroadcastState> broadcastStates, + final ProcessingTimeService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.getTimestamp(); + } + + @Override + public BroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + BroadcastState state = (BroadcastState) states.get(stateDescriptor); + if (state == null) { + throw new IllegalArgumentException("The requested state does not exist. " + + "Check for typos in your state descriptor, or specify the state descriptor " + + "in the datastream.broadcast(...) 
call if you forgot to register it."); + } + return state; + } + + @Override + public void output(OutputTag outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public long currentProcessingTime() { + return timerService.getCurrentProcessingTime(); + } + + @Override + public long currentWatermark() { + return currentWatermark; + } + } + + private class ReadOnlyContextImpl extends BroadcastProcessFunction.ReadOnlyContext { + + private final Map, BroadcastState> states; + + private final ProcessingTimeService timerService; + + private StreamRecord element; + + ReadOnlyContextImpl( + final BroadcastProcessFunction function, + final Map, BroadcastState> broadcastStates, + final ProcessingTimeService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.hasTimestamp() ? element.getTimestamp() : null; + } + + @Override + public void output(OutputTag outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public long currentProcessingTime() { + return timerService.getCurrentProcessingTime(); + } + + @Override + public long currentWatermark() { + return currentWatermark; + } + + @Override + public ReadOnlyBroadcastState getBroadcastState(MapStateDescriptor stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + ReadOnlyBroadcastState state = (ReadOnlyBroadcastState) states.get(stateDescriptor); + if (state == null) { + throw new IllegalArgumentException("The requested state does not exist. 
" + + "Check for typos in your state descriptor, or specify the state descriptor " + + "in the datastream.broadcast(...) call if you forgot to register it."); + } + return state; + } + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java index 5ee055c97e3b2..1c7592194fd32 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/TwoInputTransformation.java @@ -31,8 +31,8 @@ /** * This Transformation represents the application of a - * {@link org.apache.flink.streaming.api.operators.TwoInputStreamOperator} to two input - * {@code StreamTransformations}. The result is again only one stream. + * {@link TwoInputStreamOperator} to two input {@code StreamTransformations}. + * The result is again only one stream. 
* * @param The type of the elements in the first input {@code StreamTransformation} * @param The type of the elements in the second input {@code StreamTransformation} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index 6fb06d3aa92c8..59f54b5d4858b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.operators.ResourceSpec; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; @@ -37,6 +38,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.collector.selector.OutputSelector; +import org.apache.flink.streaming.api.datastream.BroadcastStream; import org.apache.flink.streaming.api.datastream.ConnectedStreams; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStreamSink; @@ -45,9 +47,12 @@ import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.datastream.SplitStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; import 
org.apache.flink.streaming.api.functions.co.CoFlatMapFunction; import org.apache.flink.streaming.api.functions.co.CoMapFunction; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; @@ -57,6 +62,7 @@ import org.apache.flink.streaming.api.operators.KeyedProcessOperator; import org.apache.flink.streaming.api.operators.ProcessOperator; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.api.windowing.assigners.GlobalWindows; import org.apache.flink.streaming.api.windowing.triggers.CountTrigger; import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger; @@ -78,8 +84,12 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +import javax.annotation.Nullable; + import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -753,6 +763,182 @@ public void onTimer( assertTrue(getOperatorForDataStream(processed) instanceof ProcessOperator); } + @Test + public void testConnectWithBroadcastTranslation() throws Exception { + + final Map expected = new HashMap<>(); + expected.put(0L, "test:0"); + expected.put(1L, "test:1"); + expected.put(2L, "test:2"); + expected.put(3L, "test:3"); + expected.put(4L, "test:4"); + expected.put(5L, "test:5"); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); + + final DataStream srcOne = env.generateSequence(0L, 5L) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + + @Override + public long extractTimestamp(Long element, long 
previousElementTimestamp) { + return element; + } + }).keyBy((KeySelector) value -> value); + + final DataStream srcTwo = env.fromCollection(expected.values()) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + @Override + public long extractTimestamp(String element, long previousElementTimestamp) { + return Long.parseLong(element.split(":")[1]); + } + }); + + final BroadcastStream broadcast = srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR); + + // the timestamp should be high enough to trigger the timer after all the elements arrive. + final DataStream output = srcOne.connect(broadcast).process( + new TestBroadcastProcessFunction(100000L, expected)); + + output.addSink(new DiscardingSink<>()); + env.execute(); + } + + private abstract static class CustomWmEmitter implements AssignerWithPunctuatedWatermarks { + + @Nullable + @Override + public Watermark checkAndGetNextWatermark(T lastElement, long extractedTimestamp) { + return new Watermark(extractedTimestamp); + } + } + + private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction { + + private final Map expectedState; + + private final long timerTimestamp; + + static final MapStateDescriptor DESCRIPTOR = new MapStateDescriptor<>( + "broadcast-state", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + TestBroadcastProcessFunction( + final long timerTS, + final Map expectedBroadcastState + ) { + expectedState = expectedBroadcastState; + timerTimestamp = timerTS; + } + + @Override + public void processElement(Long value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + ctx.timerService().registerEventTimeTimer(timerTimestamp); + } + + @Override + public void processBroadcastElement(String value, Context ctx, Collector out) throws Exception { + long key = Long.parseLong(value.split(":")[1]); + ctx.getBroadcastState(DESCRIPTOR).put(key, value); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws 
Exception { + Map map = new HashMap<>(); + for (Map.Entry entry : ctx.getBroadcastState(DESCRIPTOR).immutableEntries()) { + map.put(entry.getKey(), entry.getValue()); + } + Assert.assertEquals(expectedState, map); + } + } + + /** + * Tests that with a {@link KeyedStream} we have to provide a {@link KeyedBroadcastProcessFunction}. + */ + @Test(expected = IllegalArgumentException.class) + public void testFailedTranslationOnKeyed() { + + final MapStateDescriptor descriptor = new MapStateDescriptor<>( + "broadcast", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + final DataStream srcOne = env.generateSequence(0L, 5L) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + + @Override + public long extractTimestamp(Long element, long previousElementTimestamp) { + return element; + } + }).keyBy((KeySelector) value -> value); + + final DataStream srcTwo = env.fromElements("Test:0", "Test:1", "Test:2", "Test:3", "Test:4", "Test:5") + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + @Override + public long extractTimestamp(String element, long previousElementTimestamp) { + return Long.parseLong(element.split(":")[1]); + } + }); + + BroadcastStream broadcast = srcTwo.broadcast(descriptor); + srcOne.connect(broadcast) + .process(new BroadcastProcessFunction() { + @Override + public void processBroadcastElement(String value, Context ctx, Collector out) throws Exception { + // do nothing + } + + @Override + public void processElement(Long value, ReadOnlyContext ctx, Collector out) throws Exception { + // do nothing + } + }); + } + + /** + * Tests that with a non-keyed stream we have to provide a {@link BroadcastProcessFunction}. 
+ */ + @Test(expected = IllegalArgumentException.class) + public void testFailedTranslationOnNonKeyed() { + + final MapStateDescriptor descriptor = new MapStateDescriptor<>( + "broadcast", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + final DataStream srcOne = env.generateSequence(0L, 5L) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + + @Override + public long extractTimestamp(Long element, long previousElementTimestamp) { + return element; + } + }); + + final DataStream srcTwo = env.fromElements("Test:0", "Test:1", "Test:2", "Test:3", "Test:4", "Test:5") + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + @Override + public long extractTimestamp(String element, long previousElementTimestamp) { + return Long.parseLong(element.split(":")[1]); + } + }); + + BroadcastStream broadcast = srcTwo.broadcast(descriptor); + srcOne.connect(broadcast) + .process(new KeyedBroadcastProcessFunction() { + + @Override + public void processBroadcastElement(String value, Context ctx, Collector out) throws Exception { + // do nothing + } + + @Override + public void processElement(Long value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + // do nothing + } + }); + } + @Test public void operatorTest() { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java new file mode 100644 index 0000000000000..3398d14b58126 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java @@ -0,0 +1,655 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Function; + +/** + * Tests for the {@link CoBroadcastWithKeyedOperator}. 
+ */ +public class CoBroadcastWithKeyedOperatorTest { + + private static final MapStateDescriptor STATE_DESCRIPTOR = + new MapStateDescriptor<>( + "broadcast-state", + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO + ); + + @Test + public void testFunctionWithTimer() throws Exception { + + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new FunctionWithTimerOnKeyed(41L)) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + + testHarness.processWatermark1(new Watermark(40L)); + testHarness.processWatermark2(new Watermark(40L)); + testHarness.processElement1(new StreamRecord<>("6", 13L)); + testHarness.processElement1(new StreamRecord<>("6", 15L)); + + testHarness.processWatermark1(new Watermark(50L)); + testHarness.processWatermark2(new Watermark(50L)); + + Queue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new Watermark(10L)); + expectedOutput.add(new StreamRecord<>("BR:5 WM:10 TS:12", 12L)); + expectedOutput.add(new Watermark(40L)); + expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 TS:13", 13L)); + expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 TS:15", 15L)); + expectedOutput.add(new StreamRecord<>("TIMER:41", 41L)); + expectedOutput.add(new Watermark(50L)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + } + } + + /** + * {@link KeyedBroadcastProcessFunction} that registers a timer and emits + * for every element the watermark and the timestamp of the element. 
+ */ + private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final long timerTS; + + FunctionWithTimerOnKeyed(long timerTS) { + this.timerTS = timerTS; + } + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + out.collect("BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + ctx.timerService().registerEventTimeTimer(timerTS); + out.collect("NON-BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + out.collect("TIMER:" + timestamp); + } + } + + @Test + public void testSideOutput() throws Exception { + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new FunctionWithSideOutput()) + ) { + + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + + testHarness.processWatermark1(new Watermark(40L)); + testHarness.processWatermark2(new Watermark(40L)); + testHarness.processElement1(new StreamRecord<>("6", 13L)); + testHarness.processElement1(new StreamRecord<>("6", 15L)); + + testHarness.processWatermark1(new Watermark(50L)); + testHarness.processWatermark2(new Watermark(50L)); + + Queue> expectedBr = new ConcurrentLinkedQueue<>(); + expectedBr.add(new StreamRecord<>("BR:5 WM:10 TS:12", 12L)); + + Queue> expectedNonBr = new ConcurrentLinkedQueue<>(); + expectedNonBr.add(new StreamRecord<>("NON-BR:6 WM:40 TS:13", 13L)); + expectedNonBr.add(new StreamRecord<>("NON-BR:6 WM:40 TS:15", 15L)); + + 
TestHarnessUtil.assertOutputEquals( + "Wrong Side Output", + expectedBr, + testHarness.getSideOutput(FunctionWithSideOutput.BROADCAST_TAG)); + + TestHarnessUtil.assertOutputEquals( + "Wrong Side Output", + expectedNonBr, + testHarness.getSideOutput(FunctionWithSideOutput.NON_BROADCAST_TAG)); + } + } + + /** + * {@link KeyedBroadcastProcessFunction} that emits elements on side outputs. + */ + private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + static final OutputTag BROADCAST_TAG = new OutputTag("br-out") { + private static final long serialVersionUID = -6899484480421899631L; + }; + + static final OutputTag NON_BROADCAST_TAG = new OutputTag("non-br-out") { + private static final long serialVersionUID = 3837387110613831791L; + }; + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + ctx.output(BROADCAST_TAG, "BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + ctx.output(NON_BROADCAST_TAG, "NON-BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + } + + @Test + public void testFunctionWithBroadcastState() throws Exception { + + final Map expectedBroadcastState = new HashMap<>(); + expectedBroadcastState.put("5.key", 5); + expectedBroadcastState.put("34.key", 34); + expectedBroadcastState.put("53.key", 53); + expectedBroadcastState.put("12.key", 12); + expectedBroadcastState.put("98.key", 98); + + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new FunctionWithBroadcastState("key", expectedBroadcastState, 41L)) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + + 
testHarness.processElement2(new StreamRecord<>(5, 10L)); + testHarness.processElement2(new StreamRecord<>(34, 12L)); + testHarness.processElement2(new StreamRecord<>(53, 15L)); + testHarness.processElement2(new StreamRecord<>(12, 16L)); + testHarness.processElement2(new StreamRecord<>(98, 19L)); + + testHarness.processElement1(new StreamRecord<>("trigger", 13L)); + + testHarness.processElement2(new StreamRecord<>(51, 21L)); + + testHarness.processWatermark1(new Watermark(50L)); + testHarness.processWatermark2(new Watermark(50L)); + + Queue output = testHarness.getOutput(); + Assert.assertEquals(3L, output.size()); + + Object firstRawWm = output.poll(); + Assert.assertTrue(firstRawWm instanceof Watermark); + Watermark firstWm = (Watermark) firstRawWm; + Assert.assertEquals(10L, firstWm.getTimestamp()); + + Object rawOutputElem = output.poll(); + Assert.assertTrue(rawOutputElem instanceof StreamRecord); + StreamRecord outputRec = (StreamRecord) rawOutputElem; + Assert.assertTrue(outputRec.getValue() instanceof String); + String outputElem = (String) outputRec.getValue(); + + expectedBroadcastState.put("51.key", 51); + List> expectedEntries = new ArrayList<>(); + expectedEntries.addAll(expectedBroadcastState.entrySet()); + String expected = "TS:41 " + mapToString(expectedEntries); + Assert.assertEquals(expected, outputElem); + + Object secondRawWm = output.poll(); + Assert.assertTrue(secondRawWm instanceof Watermark); + Watermark secondWm = (Watermark) secondRawWm; + Assert.assertEquals(50L, secondWm.getTimestamp()); + } + } + + private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final String keyPostfix; + private final Map expectedBroadcastState; + private final long timerTs; + + FunctionWithBroadcastState( + final String keyPostfix, + final Map expectedBroadcastState, + final long timerTs + ) { + this.keyPostfix = 
Preconditions.checkNotNull(keyPostfix); + this.expectedBroadcastState = Preconditions.checkNotNull(expectedBroadcastState); + this.timerTs = timerTs; + } + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + // put an element in the broadcast state + final String key = value + "." + keyPostfix; + ctx.getBroadcastState(STATE_DESCRIPTOR).put(key, value); + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + Iterable> broadcastStateIt = ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries(); + Iterator> iter = broadcastStateIt.iterator(); + + for (int i = 0; i < expectedBroadcastState.size(); i++) { + Assert.assertTrue(iter.hasNext()); + + Map.Entry entry = iter.next(); + Assert.assertTrue(expectedBroadcastState.containsKey(entry.getKey())); + Assert.assertEquals(expectedBroadcastState.get(entry.getKey()), entry.getValue()); + } + + Assert.assertFalse(iter.hasNext()); + + ctx.timerService().registerEventTimeTimer(timerTs); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + final Iterator> iter = ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries().iterator(); + + final List> map = new ArrayList<>(); + while (iter.hasNext()) { + map.add(iter.next()); + } + final String mapToStr = mapToString(map); + out.collect("TS:" + timestamp + " " + mapToStr); + } + } + + @Test + public void testScaleUp() throws Exception { + final Set keysToRegister = new HashSet<>(); + keysToRegister.add("test1"); + keysToRegister.add("test2"); + keysToRegister.add("test3"); + + final OperatorStateHandles mergedSnapshot; + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 0); + + TwoInputStreamOperatorTestHarness testHarness2 = 
getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 1) + + ) { + + // make sure all operators have the same state + testHarness1.processElement2(new StreamRecord<>(3)); + testHarness2.processElement2(new StreamRecord<>(3)); + + mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState( + testHarness1.snapshot(0L, 0L), + testHarness2.snapshot(0L, 0L) + ); + } + + final Set expected = new HashSet<>(3); + expected.add("test1=3"); + expected.add("test2=3"); + expected.add("test3=3"); + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 0, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 1, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness3 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 2, + mergedSnapshot) + ) { + testHarness1.processElement1(new StreamRecord<>("trigger")); + testHarness2.processElement1(new StreamRecord<>("trigger")); + testHarness3.processElement1(new StreamRecord<>("trigger")); + + Queue output1 = testHarness1.getOutput(); + Queue output2 = testHarness2.getOutput(); + Queue output3 = testHarness3.getOutput(); + + Assert.assertEquals(expected.size(), output1.size()); + for (Object o: output1) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + Assert.assertEquals(expected.size(), output2.size()); + for (Object o: output2) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + 
Assert.assertEquals(expected.size(), output3.size()); + for (Object o: output3) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + } + } + + @Test + public void testScaleDown() throws Exception { + final Set keysToRegister = new HashSet<>(); + keysToRegister.add("test1"); + keysToRegister.add("test2"); + keysToRegister.add("test3"); + + final OperatorStateHandles mergedSnapshot; + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 0); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 1); + + TwoInputStreamOperatorTestHarness testHarness3 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 2) + ) { + + // make sure all operators have the same state + testHarness1.processElement2(new StreamRecord<>(3)); + testHarness2.processElement2(new StreamRecord<>(3)); + testHarness3.processElement2(new StreamRecord<>(3)); + + mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState( + testHarness1.snapshot(0L, 0L), + testHarness2.snapshot(0L, 0L), + testHarness3.snapshot(0L, 0L) + ); + } + + final Set expected = new HashSet<>(3); + expected.add("test1=3"); + expected.add("test2=3"); + expected.add("test3=3"); + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 0, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new 
TestFunctionWithOutput(keysToRegister), + 10, + 2, + 1, + mergedSnapshot) + ) { + + testHarness1.processElement1(new StreamRecord<>("trigger")); + testHarness2.processElement1(new StreamRecord<>("trigger")); + + Queue output1 = testHarness1.getOutput(); + Queue output2 = testHarness2.getOutput(); + + Assert.assertEquals(expected.size(), output1.size()); + for (Object o: output1) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + Assert.assertEquals(expected.size(), output2.size()); + for (Object o: output2) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + } + } + + private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final Set keysToRegister; + + TestFunctionWithOutput(Set keysToRegister) { + this.keysToRegister = Preconditions.checkNotNull(keysToRegister); + } + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + // put an element in the broadcast state + for (String k : keysToRegister) { + ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, value); + } + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + for (Map.Entry entry : ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries()) { + out.collect(entry.toString()); + } + } + } + + @Test + public void testNoKeyedStateOnBroadcastSide() throws Exception { + + boolean exceptionThrown = false; + + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new KeyedBroadcastProcessFunction() { + + private static final long serialVersionUID = -1725365436500098384L; + + private final ValueStateDescriptor valueState = new ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO); + 
+ @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + getRuntimeContext().getState(valueState).value(); // this should fail + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + // do nothing + } + }) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + } catch (NullPointerException e) { + Assert.assertEquals("No key set. This method should not be called outside of a keyed context.", e.getMessage()); + exceptionThrown = true; + } + + if (!exceptionThrown) { + Assert.fail("No exception thrown"); + } + } + + private static class IdentityKeySelector implements KeySelector { + private static final long serialVersionUID = 1L; + + @Override + public T getKey(T value) throws Exception { + return value; + } + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final TypeInformation keyTypeInfo, + final KeySelector keyKeySelector, + final KeyedBroadcastProcessFunction function) throws Exception { + + return getInitializedTestHarness( + keyTypeInfo, + keyKeySelector, + function, + 1, + 1, + 0); + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final TypeInformation keyTypeInfo, + final KeySelector keyKeySelector, + final KeyedBroadcastProcessFunction function, + final int maxParallelism, + final int numTasks, + final int taskIdx) throws Exception { + + return getInitializedTestHarness( + keyTypeInfo, + keyKeySelector, + function, + maxParallelism, + numTasks, + taskIdx, + null); + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final TypeInformation keyTypeInfo, + final KeySelector keyKeySelector, + final KeyedBroadcastProcessFunction function, + final int maxParallelism, + final int numTasks, + final int taskIdx, + final 
OperatorStateHandles initState) throws Exception { + + final TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + new CoBroadcastWithKeyedOperator<>( + Preconditions.checkNotNull(function), + Collections.singletonList(STATE_DESCRIPTOR)), + keyKeySelector, + null, + keyTypeInfo, + maxParallelism, + numTasks, + taskIdx + ); + + testHarness.setup(); + testHarness.initializeState(initState); + testHarness.open(); + + return testHarness; + } + + private static String mapToString(List> entries) { + entries.sort( + Comparator.comparing( + (Function, String>) Map.Entry::getKey + ).thenComparingInt(Map.Entry::getValue) + ); + + final StringBuilder builder = new StringBuilder(); + for (Map.Entry entry : entries) { + builder.append(' ') + .append(entry.getKey()) + .append('=') + .append(entry.getValue()); + } + return builder.toString(); + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java new file mode 100644 index 0000000000000..066a80ff95ac0 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles; +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +/** + * Tests for the {@link CoBroadcastWithNonKeyedOperator}. 
+ */ +public class CoBroadcastWithNonKeyedOperatorTest { + + private static final MapStateDescriptor STATE_DESCRIPTOR = + new MapStateDescriptor<>( + "broadcast-state", + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO + ); + + @Test + public void testBroadcastState() throws Exception { + + final Set keysToRegister = new HashSet<>(); + keysToRegister.add("test1"); + keysToRegister.add("test2"); + keysToRegister.add("test3"); + + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + new TestFunction(keysToRegister)) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + + testHarness.processWatermark1(new Watermark(40L)); + testHarness.processWatermark2(new Watermark(40L)); + testHarness.processElement1(new StreamRecord<>("6", 13L)); + testHarness.processElement1(new StreamRecord<>("6", 15L)); + + testHarness.processWatermark1(new Watermark(50L)); + testHarness.processWatermark2(new Watermark(50L)); + + Queue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new Watermark(10L)); + expectedOutput.add(new StreamRecord<>("5WM:10 TS:12", 12L)); + expectedOutput.add(new Watermark(40L)); + expectedOutput.add(new StreamRecord<>("6WM:40 TS:13", 13L)); + expectedOutput.add(new StreamRecord<>("6WM:40 TS:15", 15L)); + expectedOutput.add(new Watermark(50L)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + } + } + + private static class TestFunction extends BroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final Set keysToRegister; + + TestFunction(Set keysToRegister) { + this.keysToRegister = Preconditions.checkNotNull(keysToRegister); + } + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + // put an element in the 
broadcast state + for (String k : keysToRegister) { + ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, value); + } + out.collect(value + "WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + Set retrievedKeySet = new HashSet<>(); + for (Map.Entry entry : ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries()) { + retrievedKeySet.add(entry.getKey()); + } + + Assert.assertEquals(keysToRegister, retrievedKeySet); + + out.collect(value + "WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + } + + @Test + public void testSideOutput() throws Exception { + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + new FunctionWithSideOutput()) + ) { + + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + + testHarness.processWatermark1(new Watermark(40L)); + testHarness.processWatermark2(new Watermark(40L)); + testHarness.processElement1(new StreamRecord<>("6", 13L)); + testHarness.processElement1(new StreamRecord<>("6", 15L)); + + testHarness.processWatermark1(new Watermark(50L)); + testHarness.processWatermark2(new Watermark(50L)); + + ConcurrentLinkedQueue> expectedBr = new ConcurrentLinkedQueue<>(); + expectedBr.add(new StreamRecord<>("BR:5 WM:10 TS:12", 12L)); + + ConcurrentLinkedQueue> expectedNonBr = new ConcurrentLinkedQueue<>(); + expectedNonBr.add(new StreamRecord<>("NON-BR:6 WM:40 TS:13", 13L)); + expectedNonBr.add(new StreamRecord<>("NON-BR:6 WM:40 TS:15", 15L)); + + ConcurrentLinkedQueue> brSideOutput = testHarness.getSideOutput(FunctionWithSideOutput.BROADCAST_TAG); + ConcurrentLinkedQueue> nonBrSideOutput = testHarness.getSideOutput(FunctionWithSideOutput.NON_BROADCAST_TAG); + + TestHarnessUtil.assertOutputEquals("Wrong Side Output", expectedBr, brSideOutput); + 
TestHarnessUtil.assertOutputEquals("Wrong Side Output", expectedNonBr, nonBrSideOutput); + } + } + + /** + * {@link BroadcastProcessFunction} that emits elements on side outputs. + */ + private static class FunctionWithSideOutput extends BroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + static final OutputTag BROADCAST_TAG = new OutputTag("br-out") { + private static final long serialVersionUID = 8037335313997479800L; + }; + + static final OutputTag NON_BROADCAST_TAG = new OutputTag("non-br-out") { + private static final long serialVersionUID = -1092362442658548175L; + }; + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + ctx.output(BROADCAST_TAG, "BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + ctx.output(NON_BROADCAST_TAG, "NON-BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); + } + } + + @Test + public void testScaleUp() throws Exception { + final Set keysToRegister = new HashSet<>(); + keysToRegister.add("test1"); + keysToRegister.add("test2"); + keysToRegister.add("test3"); + + final OperatorStateHandles mergedSnapshot; + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 0); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 1) + ) { + // make sure all operators have the same state + testHarness1.processElement2(new StreamRecord<>(3)); + testHarness2.processElement2(new StreamRecord<>(3)); + + mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState( + testHarness1.snapshot(0L, 0L), + testHarness2.snapshot(0L, 0L) + ); + } + + final Set expected = new HashSet<>(3); + 
expected.add("test1=3"); + expected.add("test2=3"); + expected.add("test3=3"); + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 0, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 1, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness3 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 2, + mergedSnapshot) + ) { + testHarness1.processElement1(new StreamRecord<>("trigger")); + testHarness2.processElement1(new StreamRecord<>("trigger")); + testHarness3.processElement1(new StreamRecord<>("trigger")); + + Queue output1 = testHarness1.getOutput(); + Queue output2 = testHarness2.getOutput(); + Queue output3 = testHarness3.getOutput(); + + Assert.assertEquals(expected.size(), output1.size()); + for (Object o: output1) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + Assert.assertEquals(expected.size(), output2.size()); + for (Object o: output2) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + Assert.assertEquals(expected.size(), output3.size()); + for (Object o: output3) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + } + } + + @Test + public void testScaleDown() throws Exception { + final Set keysToRegister = new HashSet<>(); + keysToRegister.add("test1"); + keysToRegister.add("test2"); + keysToRegister.add("test3"); + + final OperatorStateHandles mergedSnapshot; + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 0); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 1); 
+ + TwoInputStreamOperatorTestHarness testHarness3 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 3, + 2) + ) { + + // make sure all operators have the same state + testHarness1.processElement2(new StreamRecord<>(3)); + testHarness2.processElement2(new StreamRecord<>(3)); + testHarness3.processElement2(new StreamRecord<>(3)); + + mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState( + testHarness1.snapshot(0L, 0L), + testHarness2.snapshot(0L, 0L), + testHarness3.snapshot(0L, 0L) + ); + } + + final Set expected = new HashSet<>(3); + expected.add("test1=3"); + expected.add("test2=3"); + expected.add("test3=3"); + + try ( + TwoInputStreamOperatorTestHarness testHarness1 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 0, + mergedSnapshot); + + TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( + new TestFunctionWithOutput(keysToRegister), + 10, + 2, + 1, + mergedSnapshot) + ) { + testHarness1.processElement1(new StreamRecord<>("trigger")); + testHarness2.processElement1(new StreamRecord<>("trigger")); + + Queue output1 = testHarness1.getOutput(); + Queue output2 = testHarness2.getOutput(); + + Assert.assertEquals(expected.size(), output1.size()); + for (Object o: output1) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + + Assert.assertEquals(expected.size(), output2.size()); + for (Object o: output2) { + StreamRecord rec = (StreamRecord) o; + Assert.assertTrue(expected.contains(rec.getValue())); + } + } + } + + private static class TestFunctionWithOutput extends BroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final Set keysToRegister; + + TestFunctionWithOutput(Set keysToRegister) { + this.keysToRegister = Preconditions.checkNotNull(keysToRegister); + } + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector 
out) throws Exception { + // put an element in the broadcast state + for (String k : keysToRegister) { + ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, value); + } + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + for (Map.Entry entry : ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries()) { + out.collect(entry.toString()); + } + } + } + + @Test + public void testNoKeyedStateOnBroadcastSide() throws Exception { + + boolean exceptionThrown = false; + + try ( + TwoInputStreamOperatorTestHarness testHarness = + getInitializedTestHarness( + new BroadcastProcessFunction() { + private static final long serialVersionUID = -1725365436500098384L; + + private final ValueStateDescriptor valueState = new ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO); + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + getRuntimeContext().getState(valueState).value(); // this should fail + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + // do nothing + } + }) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement2(new StreamRecord<>(5, 12L)); + } catch (NullPointerException e) { + Assert.assertEquals("Keyed state can only be used on a 'keyed stream', i.e., after a 'keyBy()' operation.", e.getMessage()); + exceptionThrown = true; + } + + if (!exceptionThrown) { + Assert.fail("No exception thrown"); + } + } + + @Test + public void testNoKeyedStateOnNonBroadcastSide() throws Exception { + + boolean exceptionThrown = false; + + try ( + TwoInputStreamOperatorTestHarness testHarness = + getInitializedTestHarness( + new BroadcastProcessFunction() { + private static final long serialVersionUID = -1725365436500098384L; + + private final ValueStateDescriptor valueState = new 
ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO); + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + // do nothing + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + getRuntimeContext().getState(valueState).value(); // this should fail + } + }) + ) { + testHarness.processWatermark1(new Watermark(10L)); + testHarness.processWatermark2(new Watermark(10L)); + testHarness.processElement1(new StreamRecord<>("5", 12L)); + } catch (NullPointerException e) { + Assert.assertEquals("Keyed state can only be used on a 'keyed stream', i.e., after a 'keyBy()' operation.", e.getMessage()); + exceptionThrown = true; + } + + if (!exceptionThrown) { + Assert.fail("No exception thrown"); + } + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final BroadcastProcessFunction function) throws Exception { + + return getInitializedTestHarness( + function, + 1, + 1, + 0); + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final BroadcastProcessFunction function, + final int maxParallelism, + final int numTasks, + final int taskIdx) throws Exception { + + return getInitializedTestHarness( + function, + maxParallelism, + numTasks, + taskIdx, + null); + } + + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + final BroadcastProcessFunction function, + final int maxParallelism, + final int numTasks, + final int taskIdx, + final OperatorStateHandles initState) throws Exception { + + TwoInputStreamOperatorTestHarness testHarness = new TwoInputStreamOperatorTestHarness<>( + new CoBroadcastWithNonKeyedOperator<>( + Preconditions.checkNotNull(function), + Collections.singletonList(STATE_DESCRIPTOR)), + maxParallelism, numTasks, taskIdx + ); + testHarness.setup(); + testHarness.initializeState(initState); + testHarness.open(); + + return testHarness; + } +} diff 
--git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java index d0bbf8f89ae62..7bb697331da6c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java @@ -29,7 +29,7 @@ * and watermarks into the operator. {@link java.util.Deque}s containing the emitted elements * and watermarks can be retrieved. you are free to modify these. */ -public class TwoInputStreamOperatorTestHarnessextends AbstractStreamOperatorTestHarness { +public class TwoInputStreamOperatorTestHarness extends AbstractStreamOperatorTestHarness { private final TwoInputStreamOperator twoInputOperator; From be16aff664a1040675787e19b74caddaffa7e0f2 Mon Sep 17 00:00:00 2001 From: kkloudas Date: Mon, 29 Jan 2018 16:17:24 +0100 Subject: [PATCH 4/6] [FLINK-8345] Add iterator of keyed state on broadcast side of connected streams. 
--- .../state/AbstractKeyedStateBackend.java | 33 +++++ .../runtime/state/KeyedStateBackend.java | 18 +++ .../runtime/state/KeyedStateFunction.java | 38 ++++++ .../datastream/BroadcastConnectedStream.java | 10 +- .../co/KeyedBroadcastProcessFunction.java | 39 +++++- .../co/CoBroadcastWithKeyedOperator.java | 39 ++++-- .../flink/streaming/api/DataStreamTest.java | 8 +- .../co/CoBroadcastWithKeyedOperatorTest.java | 128 +++++++++++++++--- 8 files changed, 274 insertions(+), 39 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java index cc53c0cb895f9..d159d46f54ff8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java @@ -279,6 +279,39 @@ public KeyGroupRange getKeyGroupRange() { return keyGroupRange; } + /** + * @see KeyedStateBackend + */ + @Override + public void applyToAllKeys( + final N namespace, + final TypeSerializer namespaceSerializer, + final StateDescriptor stateDescriptor, + final KeyedStateFunction function) throws Exception { + + try { + getKeys(stateDescriptor.getName(), namespace) + .forEach((K key) -> { + setCurrentKey(key); + try { + function.process( + key, + getPartitionedState( + namespace, + namespaceSerializer, + stateDescriptor) + ); + } catch (Throwable e) { + // we wrap the checked exception in an unchecked + // one and catch it (and re-throw it) later. 
+ throw new RuntimeException(e); + } + }); + } catch (RuntimeException e) { + throw e; + } + } + /** * @see KeyedStateBackend */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index c74cfcf20df4c..cbe40ee7b1080 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -38,6 +38,24 @@ public interface KeyedStateBackend extends InternalKeyContext { */ void setCurrentKey(K newKey); + /** + * Applies the provided {@link KeyedStateFunction} to the state with the provided + * {@link StateDescriptor} of all the currently active keys. + * + * @param namespace the namespace of the state. + * @param namespaceSerializer the serializer for the namespace. + * @param stateDescriptor the descriptor of the state to which the function is going to be applied. + * @param function the function to be applied to the keyed state. + * + * @param The type of the namespace. + * @param The type of the state. + */ + void applyToAllKeys( + final N namespace, + final TypeSerializer namespaceSerializer, + final StateDescriptor stateDescriptor, + final KeyedStateFunction function) throws Exception; + /** * @return A stream of all keys for the given state and namespace. Modifications to the state during iterating * over it keys are not supported. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java new file mode 100644 index 0000000000000..de23dec56565f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.state.State; + +/** + * A function to be applied to all keyed states. + * + *

This functionality is only available through the + * {@code BroadcastConnectedStream.process(final KeyedBroadcastProcessFunction function)}. + */ +public abstract class KeyedStateFunction { + + /** + * The actual method to be applied on each of the states. + * + * @param key a safe copy of the key (see {@link KeyedStateBackend#getCurrentKeySafe()}) + * whose state is being processed. + * @param state the state associated with the aforementioned key. + */ + public abstract void process(K key, S state) throws Exception; +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java index aeb3bc27b0825..453c850fbb44c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java @@ -119,18 +119,19 @@ public TypeInformation getType2() { * {@link KeyedBroadcastProcessFunction} on them, thereby creating a transformed output stream. * * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. + * @param The type of the keys in the keyed stream. * @param The type of the output elements. * @return The transformed {@link DataStream}. 
*/ @PublicEvolving - public SingleOutputStreamOperator process(final KeyedBroadcastProcessFunction function) { + public SingleOutputStreamOperator process(final KeyedBroadcastProcessFunction function) { TypeInformation outTypeInfo = TypeExtractor.getBinaryOperatorReturnType( function, KeyedBroadcastProcessFunction.class, - 0, 1, 2, + 3, TypeExtractor.NO_INDEX, TypeExtractor.NO_INDEX, TypeExtractor.NO_INDEX, @@ -148,12 +149,13 @@ public SingleOutputStreamOperator process(final KeyedBroadcastProcess * * @param function The {@link KeyedBroadcastProcessFunction} that is called for each element in the stream. * @param outTypeInfo The type of the output elements. + * @param The type of the keys in the keyed stream. * @param The type of the output elements. * @return The transformed {@link DataStream}. */ @PublicEvolving - public SingleOutputStreamOperator process( - final KeyedBroadcastProcessFunction function, + public SingleOutputStreamOperator process( + final KeyedBroadcastProcessFunction function, final TypeInformation outTypeInfo) { Preconditions.checkNotNull(function); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java index 9d14259b9aed3..4b9f13828e0de 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java @@ -20,6 +20,9 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.streaming.api.TimeDomain; import 
org.apache.flink.streaming.api.TimerService; import org.apache.flink.util.Collector; @@ -36,7 +39,7 @@ * *

The user has to implement two methods: *

    - *
  1. the {@link #processBroadcastElement(Object, Context, Collector)} which will be applied to + *
  2. the {@link #processBroadcastElement(Object, KeyedContext, Collector)} which will be applied to * each element in the broadcast side *
  3. and the {@link #processElement(Object, KeyedReadOnlyContext, Collector)} which will be applied to the * non-broadcasted/keyed side. @@ -47,12 +50,13 @@ * {@code processElement()} has read-only access to the broadcast state, but can read/write to the keyed state and * register timers. * + * @param The key type of the input keyed stream. * @param The input type of the keyed (non-broadcast) side. * @param The input type of the broadcast side. * @param The output type of the operator. */ @PublicEvolving -public abstract class KeyedBroadcastProcessFunction extends BaseBroadcastProcessFunction { +public abstract class KeyedBroadcastProcessFunction extends BaseBroadcastProcessFunction { private static final long serialVersionUID = -2584726797564976453L; @@ -83,19 +87,22 @@ public abstract class KeyedBroadcastProcessFunction extends BaseB * *

    It can output zero or more elements using the {@link Collector} parameter, * query the current processing/event time, and also query and update the internal - * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. These can - * be done through the provided {@link Context}. + * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. In addition, it + * can register a {@link KeyedStateFunction function} to be applied to all keyed states on + * the local partition. These can be done through the provided {@link Context}. * The context is only valid during the invocation of this method, do not store it. * * @param value The stream element. * @param ctx A {@link Context} that allows querying the timestamp of the element, * querying the current processing/event time and updating the broadcast state. + * In addition, it allows the registration of a {@link KeyedStateFunction function} + * to be applied to all keyed state with a given {@link StateDescriptor} on the local partition. * The context is only valid during the invocation of this method, do not store it. * @param out The collector to emit resulting elements to * @throws Exception The function may throw exceptions which cause the streaming program * to fail and go into recovery. */ - public abstract void processBroadcastElement(final IN2 value, final Context ctx, final Collector out) throws Exception; + public abstract void processBroadcastElement(final IN2 value, final KeyedContext ctx, final Collector out) throws Exception; /** * Called when a timer set using {@link TimerService} fires. @@ -115,6 +122,28 @@ public void onTimer(final long timestamp, final OnTimerContext ctx, final Collec // the default implementation does nothing. } + /** + * A {@link BaseBroadcastProcessFunction.Context context} available to the broadcast side of + * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream}. + * + *

    Apart from the basic functionality of a {@link BaseBroadcastProcessFunction.Context context}, + * this also allows to apply a {@link KeyedStateFunction} to the (local) states of all active keys + * in the your backend. + */ + public abstract class KeyedContext extends Context { + + /** + * Applies the provided {@code function} to the state + * associated with the provided {@code state descriptor}. + * + * @param stateDescriptor the descriptor of the state to be processed. + * @param function the function to be applied. + */ + public abstract void applyToKeyedState( + final StateDescriptor stateDescriptor, + final KeyedStateFunction function) throws Exception; + } + /** * A {@link BaseBroadcastProcessFunction.Context context} available to the keyed stream side of * a {@link org.apache.flink.streaming.api.datastream.BroadcastConnectedStream} (if any). diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java index 794b0db707dfa..4872c6116fbc2 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java @@ -22,12 +22,15 @@ import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.SimpleTimerService; import 
org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; -import org.apache.flink.streaming.api.functions.co.BaseBroadcastProcessFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimer; @@ -56,7 +59,7 @@ */ @Internal public class CoBroadcastWithKeyedOperator - extends AbstractUdfStreamOperator> + extends AbstractUdfStreamOperator> implements TwoInputStreamOperator, Triggerable { private static final long serialVersionUID = 5926499536290284870L; @@ -74,7 +77,7 @@ public class CoBroadcastWithKeyedOperator private transient OnTimerContextImpl onTimerContext; public CoBroadcastWithKeyedOperator( - final KeyedBroadcastProcessFunction function, + final KeyedBroadcastProcessFunction function, final List> broadcastStateDescriptors) { super(function); this.broadcastStateDescriptors = Preconditions.checkNotNull(broadcastStateDescriptors); @@ -96,7 +99,7 @@ public void open() throws Exception { broadcastStates.put(descriptor, getOperatorStateBackend().getBroadcastState(descriptor)); } - rwContext = new ReadWriteContextImpl(userFunction, broadcastStates, timerService); + rwContext = new ReadWriteContextImpl(getKeyedStateBackend(), userFunction, broadcastStates, timerService); rContext = new ReadOnlyContextImpl(userFunction, broadcastStates, timerService); onTimerContext = new OnTimerContextImpl(userFunction, broadcastStates, timerService); } @@ -137,7 +140,9 @@ public void onProcessingTime(InternalTimer timer) throws Exce onTimerContext.timer = null; } - private class ReadWriteContextImpl extends BaseBroadcastProcessFunction.Context { + private class ReadWriteContextImpl extends KeyedBroadcastProcessFunction.KeyedContext { + + private final KeyedStateBackend keyedStateBackend; private final Map, BroadcastState> states; @@ -146,11 +151,13 @@ private class 
ReadWriteContextImpl extends BaseBroadcastProcessFunction.Context private StreamRecord element; ReadWriteContextImpl ( - final KeyedBroadcastProcessFunction function, + final KeyedStateBackend keyedStateBackend, + final KeyedBroadcastProcessFunction function, final Map, BroadcastState> broadcastStates, final TimerService timerService) { function.super(); + this.keyedStateBackend = Preconditions.checkNotNull(keyedStateBackend); this.states = Preconditions.checkNotNull(broadcastStates); this.timerService = Preconditions.checkNotNull(timerService); } @@ -192,9 +199,21 @@ public long currentProcessingTime() { public long currentWatermark() { return timerService.currentWatermark(); } + + @Override + public void applyToKeyedState( + final StateDescriptor stateDescriptor, + final KeyedStateFunction function) throws Exception { + + keyedStateBackend.applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + Preconditions.checkNotNull(stateDescriptor), + Preconditions.checkNotNull(function)); + } } - private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction.KeyedReadOnlyContext { + private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction.KeyedReadOnlyContext { private final Map, BroadcastState> states; @@ -203,7 +222,7 @@ private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction element; ReadOnlyContextImpl( - final KeyedBroadcastProcessFunction function, + final KeyedBroadcastProcessFunction function, final Map, BroadcastState> broadcastStates, final TimerService timerService) { @@ -256,7 +275,7 @@ public ReadOnlyBroadcastState getBroadcastState(MapStateDescriptor } } - private class OnTimerContextImpl extends KeyedBroadcastProcessFunction.OnTimerContext { + private class OnTimerContextImpl extends KeyedBroadcastProcessFunction.OnTimerContext { private final Map, BroadcastState> states; @@ -267,7 +286,7 @@ private class OnTimerContextImpl extends KeyedBroadcastProcessFunction timer; OnTimerContextImpl( - final 
KeyedBroadcastProcessFunction function, + final KeyedBroadcastProcessFunction function, final Map, BroadcastState> broadcastStates, final TimerService timerService) { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index 59f54b5d4858b..bcbbfd69272c2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -813,7 +813,7 @@ public Watermark checkAndGetNextWatermark(T lastElement, long extractedTimestamp } } - private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction { + private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction { private final Map expectedState; @@ -837,7 +837,7 @@ public void processElement(Long value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + public void processBroadcastElement(String value, KeyedContext ctx, Collector out) throws Exception { long key = Long.parseLong(value.split(":")[1]); ctx.getBroadcastState(DESCRIPTOR).put(key, value); } @@ -925,10 +925,10 @@ public long extractTimestamp(String element, long previousElementTimestamp) { BroadcastStream broadcast = srcTwo.broadcast(descriptor); srcOne.connect(broadcast) - .process(new KeyedBroadcastProcessFunction() { + .process(new KeyedBroadcastProcessFunction() { @Override - public void processBroadcastElement(String value, Context ctx, Collector out) throws Exception { + public void processBroadcastElement(String value, KeyedContext ctx, Collector out) throws Exception { // do nothing } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java index 
3398d14b58126..3fa439f033adc 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java @@ -18,11 +18,14 @@ package org.apache.flink.streaming.api.operators.co; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -63,6 +66,99 @@ public class CoBroadcastWithKeyedOperatorTest { BasicTypeInfo.INT_TYPE_INFO ); + /** Test the iteration over the keyed state on the broadcast side. 
*/ + @Test + public void testAccessToKeyedStateIt() throws Exception { + final List test1content = new ArrayList<>(); + test1content.add("test1"); + test1content.add("test1"); + + final List test2content = new ArrayList<>(); + test2content.add("test2"); + test2content.add("test2"); + test2content.add("test2"); + test2content.add("test2"); + + final List test3content = new ArrayList<>(); + test3content.add("test3"); + test3content.add("test3"); + test3content.add("test3"); + + final Map> expectedState = new HashMap<>(); + expectedState.put("test1", test1content); + expectedState.put("test2", test2content); + expectedState.put("test3", test3content); + + try ( + TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( + BasicTypeInfo.STRING_TYPE_INFO, + new IdentityKeySelector<>(), + new StatefulFunctionWithKeyedStateAccessedOnBroadcast(expectedState)) + ) { + + // send elements to the keyed state + testHarness.processElement1(new StreamRecord<>("test1", 12L)); + testHarness.processElement1(new StreamRecord<>("test1", 12L)); + + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + testHarness.processElement1(new StreamRecord<>("test3", 14L)); + + testHarness.processElement1(new StreamRecord<>("test2", 13L)); + + // this is the element on the broadcast side that will trigger the verification + // check the StatefulFunctionWithKeyedStateAccessedOnBroadcast#processBroadcastElement() + testHarness.processElement2(new StreamRecord<>(1, 13L)); + } + } + + /** + * Simple {@link KeyedBroadcastProcessFunction} that adds all incoming elements in the non-broadcast + * side to a listState and at the broadcast side it verifies if the stored data is the expected ones. 
+ */ + private static class StatefulFunctionWithKeyedStateAccessedOnBroadcast + extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + private final ListStateDescriptor listStateDesc = + new ListStateDescriptor<>("listStateTest", BasicTypeInfo.STRING_TYPE_INFO); + + private final Map> expectedKeyedStates; + + StatefulFunctionWithKeyedStateAccessedOnBroadcast(Map> expectedKeyedState) { + this.expectedKeyedStates = Preconditions.checkNotNull(expectedKeyedState); + } + + @Override + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { + // put an element in the broadcast state + ctx.applyToKeyedState( + listStateDesc, + new KeyedStateFunction>() { + @Override + public void process(String key, ListState state) throws Exception { + final Iterator it = state.get().iterator(); + + final List list = new ArrayList<>(); + while (it.hasNext()) { + list.add(it.next()); + } + Assert.assertEquals(expectedKeyedStates.get(key), list); + } + }); + } + + @Override + public void processElement(String value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + getRuntimeContext().getListState(listStateDesc).add(value); + } + } + @Test public void testFunctionWithTimer() throws Exception { @@ -102,7 +198,7 @@ public void testFunctionWithTimer() throws Exception { * {@link KeyedBroadcastProcessFunction} that registers a timer and emits * for every element the watermark and the timestamp of the element. 
*/ - private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction { + private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction { private static final long serialVersionUID = 7496674620398203933L; @@ -113,7 +209,7 @@ private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunct } @Override - public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { out.collect("BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); } @@ -172,7 +268,7 @@ public void testSideOutput() throws Exception { /** * {@link KeyedBroadcastProcessFunction} that emits elements on side outputs. */ - private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunction { + private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunction { private static final long serialVersionUID = 7496674620398203933L; @@ -185,7 +281,7 @@ private static class FunctionWithSideOutput extends KeyedBroadcastProcessFunctio }; @Override - public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { ctx.output(BROADCAST_TAG, "BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp()); } @@ -254,7 +350,7 @@ public void testFunctionWithBroadcastState() throws Exception { } } - private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction { + private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction { private static final long serialVersionUID = 7496674620398203933L; @@ -273,7 +369,7 @@ private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFun } @Override - public void processBroadcastElement(Integer value, Context ctx, Collector out) 
throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { // put an element in the broadcast state final String key = value + "." + keyPostfix; ctx.getBroadcastState(STATE_DESCRIPTOR).put(key, value); @@ -501,7 +597,7 @@ public void testScaleDown() throws Exception { } } - private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction { + private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction { private static final long serialVersionUID = 7496674620398203933L; @@ -512,7 +608,7 @@ private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunctio } @Override - public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { // put an element in the broadcast state for (String k : keysToRegister) { ctx.getBroadcastState(STATE_DESCRIPTOR).put(k, value); @@ -536,14 +632,14 @@ public void testNoKeyedStateOnBroadcastSide() throws Exception { TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( BasicTypeInfo.STRING_TYPE_INFO, new IdentityKeySelector<>(), - new KeyedBroadcastProcessFunction() { + new KeyedBroadcastProcessFunction() { private static final long serialVersionUID = -1725365436500098384L; private final ValueStateDescriptor valueState = new ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO); @Override - public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + public void processBroadcastElement(Integer value, KeyedContext ctx, Collector out) throws Exception { getRuntimeContext().getState(valueState).value(); // this should fail } @@ -575,10 +671,10 @@ public T getKey(T value) throws Exception { } } - private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( + private static TwoInputStreamOperatorTestHarness 
getInitializedTestHarness( final TypeInformation keyTypeInfo, final KeySelector keyKeySelector, - final KeyedBroadcastProcessFunction function) throws Exception { + final KeyedBroadcastProcessFunction function) throws Exception { return getInitializedTestHarness( keyTypeInfo, @@ -589,10 +685,10 @@ private static TwoInputStreamOperatorTestHarness TwoInputStreamOperatorTestHarness getInitializedTestHarness( + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( final TypeInformation keyTypeInfo, final KeySelector keyKeySelector, - final KeyedBroadcastProcessFunction function, + final KeyedBroadcastProcessFunction function, final int maxParallelism, final int numTasks, final int taskIdx) throws Exception { @@ -607,10 +703,10 @@ private static TwoInputStreamOperatorTestHarness TwoInputStreamOperatorTestHarness getInitializedTestHarness( + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( final TypeInformation keyTypeInfo, final KeySelector keyKeySelector, - final KeyedBroadcastProcessFunction function, + final KeyedBroadcastProcessFunction function, final int maxParallelism, final int numTasks, final int taskIdx, From 72060e02640c43c8925639dd78560a926816d71f Mon Sep 17 00:00:00 2001 From: kkloudas Date: Mon, 29 Jan 2018 16:23:04 +0100 Subject: [PATCH 5/6] [FLINK-8446] Support multiple broadcast states. 
--- .../datastream/BroadcastConnectedStream.java | 24 ++-- .../api/datastream/BroadcastStream.java | 29 ++--- .../streaming/api/datastream/DataStream.java | 11 +- .../co/BroadcastProcessFunction.java | 2 +- .../co/KeyedBroadcastProcessFunction.java | 2 +- .../flink/streaming/api/DataStreamTest.java | 6 +- .../CoBroadcastWithNonKeyedOperatorTest.java | 116 +++++++++++++++--- 7 files changed, 134 insertions(+), 56 deletions(-) diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java index 453c850fbb44c..f3c4838125ca8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastConnectedStream.java @@ -33,14 +33,14 @@ import org.apache.flink.streaming.api.transformations.TwoInputTransformation; import org.apache.flink.util.Preconditions; -import java.util.Collections; +import java.util.List; import static java.util.Objects.requireNonNull; /** * A BroadcastConnectedStream represents the result of connecting a keyed or non-keyed stream, * with a {@link BroadcastStream} with {@link org.apache.flink.api.common.state.BroadcastState - * BroadcastState}. As in the case of {@link ConnectedStreams} these streams are useful for cases + * broadcast state(s)}. As in the case of {@link ConnectedStreams} these streams are useful for cases * where operations on one stream directly affect the operations on the other stream, usually via * shared state between the streams. * @@ -52,26 +52,24 @@ * * @param The input type of the non-broadcast side. * @param The input type of the broadcast side. - * @param The key type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. 
- * @param The value type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. */ @PublicEvolving -public class BroadcastConnectedStream { +public class BroadcastConnectedStream { private final StreamExecutionEnvironment environment; private final DataStream inputStream1; - private final BroadcastStream inputStream2; - private final MapStateDescriptor broadcastStateDescriptor; + private final BroadcastStream inputStream2; + private final List> broadcastStateDescriptors; protected BroadcastConnectedStream( final StreamExecutionEnvironment env, final DataStream input1, - final BroadcastStream input2, - final MapStateDescriptor broadcastStateDescriptor) { + final BroadcastStream input2, + final List> broadcastStateDescriptors) { this.environment = requireNonNull(env); this.inputStream1 = requireNonNull(input1); this.inputStream2 = requireNonNull(input2); - this.broadcastStateDescriptor = requireNonNull(broadcastStateDescriptor); + this.broadcastStateDescriptors = requireNonNull(broadcastStateDescriptors); } public StreamExecutionEnvironment getExecutionEnvironment() { @@ -92,7 +90,7 @@ public DataStream getFirstInput() { * * @return The stream which, by convention, is the broadcast one. 
*/ - public BroadcastStream getSecondInput() { + public BroadcastStream getSecondInput() { return inputStream2; } @@ -163,7 +161,7 @@ public SingleOutputStreamOperator process( "A KeyedBroadcastProcessFunction can only be used with a keyed stream as the second input."); TwoInputStreamOperator operator = - new CoBroadcastWithKeyedOperator<>(function, Collections.singletonList(broadcastStateDescriptor)); + new CoBroadcastWithKeyedOperator<>(function, broadcastStateDescriptors); return transform("Co-Process-Broadcast-Keyed", outTypeInfo, operator); } @@ -214,7 +212,7 @@ public SingleOutputStreamOperator process( "A BroadcastProcessFunction can only be used with a non-keyed stream as the second input."); TwoInputStreamOperator operator = - new CoBroadcastWithNonKeyedOperator<>(function, Collections.singletonList(broadcastStateDescriptor)); + new CoBroadcastWithNonKeyedOperator<>(function, broadcastStateDescriptors); return transform("Co-Process-Broadcast", outTypeInfo, operator); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java index e21e36faff22b..6c56f9805cd25 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/BroadcastStream.java @@ -24,12 +24,15 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.transformations.StreamTransformation; +import java.util.Arrays; +import java.util.List; + import static java.util.Objects.requireNonNull; /** - * A {@code BroadcastStream} is a stream with {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. 
- * This can be created by any stream using the {@link DataStream#broadcast(MapStateDescriptor)} method and - * implicitly creates a state where the user can store elements of the created {@code BroadcastStream}. + * A {@code BroadcastStream} is a stream with {@link org.apache.flink.api.common.state.BroadcastState broadcast state(s)}. + * This can be created by any stream using the {@link DataStream#broadcast(MapStateDescriptor[])} method and + * implicitly creates states where the user can store elements of the created {@code BroadcastStream}. * (see {@link BroadcastConnectedStream}). * *

    Note that no further operation can be applied to these streams. The only available option is to connect them @@ -38,31 +41,29 @@ * {@link BroadcastConnectedStream} for further processing. * * @param The type of input/output elements. - * @param The key type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. - * @param The value type of the elements in the {@link org.apache.flink.api.common.state.BroadcastState BroadcastState}. */ @PublicEvolving -public class BroadcastStream { +public class BroadcastStream { private final StreamExecutionEnvironment environment; private final DataStream inputStream; /** - * The {@link org.apache.flink.api.common.state.StateDescriptor state descriptor} of the - * {@link org.apache.flink.api.common.state.BroadcastState broadcast state}. This state - * has a {@code key-value} format. + * The {@link org.apache.flink.api.common.state.StateDescriptor state descriptors} of the + * registered {@link org.apache.flink.api.common.state.BroadcastState broadcast states}. These + * states have {@code key-value} format. */ - private final MapStateDescriptor broadcastStateDescriptor; + private final List> broadcastStateDescriptors; protected BroadcastStream( final StreamExecutionEnvironment env, final DataStream input, - final MapStateDescriptor broadcastStateDescriptor) { + final MapStateDescriptor... 
broadcastStateDescriptors) { this.environment = requireNonNull(env); this.inputStream = requireNonNull(input); - this.broadcastStateDescriptor = requireNonNull(broadcastStateDescriptor); + this.broadcastStateDescriptors = Arrays.asList(requireNonNull(broadcastStateDescriptors)); } public TypeInformation getType() { @@ -77,8 +78,8 @@ public StreamTransformation getTransformation() { return inputStream.getTransformation(); } - public MapStateDescriptor getBroadcastStateDescriptor() { - return broadcastStateDescriptor; + public List> getBroadcastStateDescriptor() { + return broadcastStateDescriptors; } public StreamExecutionEnvironment getEnvironment() { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java index d85968957e057..8d18b804d2d9e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java @@ -257,7 +257,7 @@ public ConnectedStreams connect(DataStream dataStream) { * Creates a new {@link BroadcastConnectedStream} by connecting the current * {@link DataStream} or {@link KeyedStream} with a {@link BroadcastStream}. * - *

    The latter can be created using the {@link #broadcast(MapStateDescriptor)} method. + *

    The latter can be created using the {@link #broadcast(MapStateDescriptor[])} method. * *

    The resulting stream can be further processed using the {@code BroadcastConnectedStream.process(MyFunction)} * method, where {@code MyFunction} can be either a @@ -269,7 +269,7 @@ public ConnectedStreams connect(DataStream dataStream) { * @return The {@link BroadcastConnectedStream}. */ @PublicEvolving - public BroadcastConnectedStream connect(BroadcastStream broadcastStream) { + public BroadcastConnectedStream connect(BroadcastStream broadcastStream) { return new BroadcastConnectedStream<>( environment, this, @@ -402,14 +402,15 @@ public DataStream broadcast() { * it implicitly creates a {@link org.apache.flink.api.common.state.BroadcastState broadcast state} * which can be used to store the element of the stream. * + * @param broadcastStateDescriptors the descriptors of the broadcast states to create. * @return A {@link BroadcastStream} which can be used in the {@link #connect(BroadcastStream)} to * create a {@link BroadcastConnectedStream} for further processing of the elements. */ @PublicEvolving - public BroadcastStream broadcast(final MapStateDescriptor broadcastStateDescriptor) { - Preconditions.checkNotNull(broadcastStateDescriptor); + public BroadcastStream broadcast(final MapStateDescriptor... 
broadcastStateDescriptors) { + Preconditions.checkNotNull(broadcastStateDescriptors); final DataStream broadcastStream = setConnectionType(new BroadcastPartitioner<>()); - return new BroadcastStream<>(environment, broadcastStream, broadcastStateDescriptor); + return new BroadcastStream<>(environment, broadcastStream, broadcastStateDescriptors); } /** diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java index 4dcc92992cd4c..257ea834c2c68 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/BroadcastProcessFunction.java @@ -29,7 +29,7 @@ * with broadcast state, with a non-keyed {@link org.apache.flink.streaming.api.datastream.DataStream DataStream}. * *

    The stream with the broadcast state can be created using the - * {@link org.apache.flink.streaming.api.datastream.DataStream#broadcast(MapStateDescriptor) + * {@link org.apache.flink.streaming.api.datastream.DataStream#broadcast(MapStateDescriptor[]) * stream.broadcast(MapStateDescriptor)} method. * *

    The user has to implement two methods: diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java index 4b9f13828e0de..de9cb324dc307 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java @@ -34,7 +34,7 @@ * with broadcast state, with a {@link org.apache.flink.streaming.api.datastream.KeyedStream KeyedStream}. * *

    The stream with the broadcast state can be created using the - * {@link org.apache.flink.streaming.api.datastream.KeyedStream#broadcast(MapStateDescriptor) + * {@link org.apache.flink.streaming.api.datastream.KeyedStream#broadcast(MapStateDescriptor[]) * keyedStream.broadcast(MapStateDescriptor)} method. * *

    The user has to implement two methods: diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index bcbbfd69272c2..ca76ef4c2923a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -794,7 +794,7 @@ public long extractTimestamp(String element, long previousElementTimestamp) { } }); - final BroadcastStream broadcast = srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR); + final BroadcastStream broadcast = srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR); // the timestamp should be high enough to trigger the timer after all the elements arrive. final DataStream output = srcOne.connect(broadcast).process( @@ -880,7 +880,7 @@ public long extractTimestamp(String element, long previousElementTimestamp) { } }); - BroadcastStream broadcast = srcTwo.broadcast(descriptor); + BroadcastStream broadcast = srcTwo.broadcast(descriptor); srcOne.connect(broadcast) .process(new BroadcastProcessFunction() { @Override @@ -923,7 +923,7 @@ public long extractTimestamp(String element, long previousElementTimestamp) { } }); - BroadcastStream broadcast = srcTwo.broadcast(descriptor); + BroadcastStream broadcast = srcTwo.broadcast(descriptor); srcOne.connect(broadcast) .process(new KeyedBroadcastProcessFunction() { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java index 066a80ff95ac0..96e1c3e390c54 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java +++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithNonKeyedOperatorTest.java @@ -35,7 +35,7 @@ import org.junit.Assert; import org.junit.Test; -import java.util.Collections; +import java.util.Arrays; import java.util.HashSet; import java.util.Map; import java.util.Queue; @@ -54,6 +54,59 @@ public class CoBroadcastWithNonKeyedOperatorTest { BasicTypeInfo.INT_TYPE_INFO ); + private static final MapStateDescriptor STATE_DESCRIPTOR_A = + new MapStateDescriptor<>( + "broadcast-state-A", + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO + ); + + @Test + public void testMultiStateSupport() throws Exception { + try ( + TwoInputStreamOperatorTestHarness testHarness = + getInitializedTestHarness(new FunctionWithMultipleStates(), STATE_DESCRIPTOR, STATE_DESCRIPTOR_A) + ) { + testHarness.processElement2(new StreamRecord<>(5, 12L)); + testHarness.processElement2(new StreamRecord<>(6, 13L)); + + testHarness.processElement1(new StreamRecord<>("9", 15L)); + + Queue expectedBr = new ConcurrentLinkedQueue<>(); + expectedBr.add(new StreamRecord<>("9:key.6->6", 15L)); + expectedBr.add(new StreamRecord<>("9:key.5->5", 15L)); + expectedBr.add(new StreamRecord<>("9:5->value.5", 15L)); + expectedBr.add(new StreamRecord<>("9:6->value.6", 15L)); + + TestHarnessUtil.assertOutputEquals("Wrong Side Output", expectedBr, testHarness.getOutput()); + } + } + + /** + * {@link BroadcastProcessFunction} that puts elements on multiple broadcast states. + */ + private static class FunctionWithMultipleStates extends BroadcastProcessFunction { + + private static final long serialVersionUID = 7496674620398203933L; + + @Override + public void processBroadcastElement(Integer value, Context ctx, Collector out) throws Exception { + ctx.getBroadcastState(STATE_DESCRIPTOR).put("key." + value, value); + ctx.getBroadcastState(STATE_DESCRIPTOR_A).put(value, "value." 
+ value); + } + + @Override + public void processElement(String value, ReadOnlyContext ctx, Collector out) throws Exception { + for (Map.Entry entry: ctx.getBroadcastState(STATE_DESCRIPTOR).immutableEntries()) { + out.collect(value + ":" + entry.getKey() + "->" + entry.getValue()); + } + + for (Map.Entry entry: ctx.getBroadcastState(STATE_DESCRIPTOR_A).immutableEntries()) { + out.collect(value + ":" + entry.getKey() + "->" + entry.getValue()); + } + } + } + @Test public void testBroadcastState() throws Exception { @@ -64,7 +117,7 @@ public void testBroadcastState() throws Exception { try ( TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( - new TestFunction(keysToRegister)) + new TestFunction(keysToRegister), STATE_DESCRIPTOR) ) { testHarness.processWatermark1(new Watermark(10L)); testHarness.processWatermark2(new Watermark(10L)); @@ -127,7 +180,7 @@ public void processElement(String value, ReadOnlyContext ctx, Collector public void testSideOutput() throws Exception { try ( TwoInputStreamOperatorTestHarness testHarness = getInitializedTestHarness( - new FunctionWithSideOutput()) + new FunctionWithSideOutput(), STATE_DESCRIPTOR) ) { testHarness.processWatermark1(new Watermark(10L)); @@ -197,13 +250,15 @@ public void testScaleUp() throws Exception { new TestFunctionWithOutput(keysToRegister), 10, 2, - 0); + 0, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 2, - 1) + 1, + STATE_DESCRIPTOR) ) { // make sure all operators have the same state testHarness1.processElement2(new StreamRecord<>(3)); @@ -226,21 +281,24 @@ public void testScaleUp() throws Exception { 10, 3, 0, - mergedSnapshot); + mergedSnapshot, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 3, 1, - mergedSnapshot); + mergedSnapshot, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness 
testHarness3 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 3, 2, - mergedSnapshot) + mergedSnapshot, + STATE_DESCRIPTOR) ) { testHarness1.processElement1(new StreamRecord<>("trigger")); testHarness2.processElement1(new StreamRecord<>("trigger")); @@ -284,19 +342,22 @@ public void testScaleDown() throws Exception { new TestFunctionWithOutput(keysToRegister), 10, 3, - 0); + 0, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 3, - 1); + 1, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness testHarness3 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 3, - 2) + 2, + STATE_DESCRIPTOR) ) { // make sure all operators have the same state @@ -322,14 +383,16 @@ public void testScaleDown() throws Exception { 10, 2, 0, - mergedSnapshot); + mergedSnapshot, + STATE_DESCRIPTOR); TwoInputStreamOperatorTestHarness testHarness2 = getInitializedTestHarness( new TestFunctionWithOutput(keysToRegister), 10, 2, 1, - mergedSnapshot) + mergedSnapshot, + STATE_DESCRIPTOR) ) { testHarness1.processElement1(new StreamRecord<>("trigger")); testHarness2.processElement1(new StreamRecord<>("trigger")); @@ -452,40 +515,55 @@ public void processElement(String value, ReadOnlyContext ctx, Collector } private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( - final BroadcastProcessFunction function) throws Exception { + final BroadcastProcessFunction function, + final MapStateDescriptor... descriptors) throws Exception { return getInitializedTestHarness( function, 1, 1, - 0); + 0, + descriptors); } private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( final BroadcastProcessFunction function, final int maxParallelism, final int numTasks, - final int taskIdx) throws Exception { + final int taskIdx, + final MapStateDescriptor... 
descriptors) throws Exception { return getInitializedTestHarness( function, maxParallelism, numTasks, taskIdx, - null); + null, + descriptors); } +// private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( +// final BroadcastProcessFunction function, +// final int maxParallelism, +// final int numTasks, +// final int taskIdx, +// final OperatorStateHandles initState) throws Exception { +// +// return getInitializedTestHarness(function, maxParallelism, numTasks, taskIdx, initState, STATE_DESCRIPTOR); +// } + private static TwoInputStreamOperatorTestHarness getInitializedTestHarness( final BroadcastProcessFunction function, final int maxParallelism, final int numTasks, final int taskIdx, - final OperatorStateHandles initState) throws Exception { + final OperatorStateHandles initState, + final MapStateDescriptor... descriptors) throws Exception { TwoInputStreamOperatorTestHarness testHarness = new TwoInputStreamOperatorTestHarness<>( new CoBroadcastWithNonKeyedOperator<>( Preconditions.checkNotNull(function), - Collections.singletonList(STATE_DESCRIPTOR)), + Arrays.asList(descriptors)), maxParallelism, numTasks, taskIdx ); testHarness.setup(); From d5f821a517e4692f5c896eb6a418b4aa65af649b Mon Sep 17 00:00:00 2001 From: kkloudas Date: Fri, 2 Feb 2018 14:59:34 +0100 Subject: [PATCH 6/6] [hotfix] Create BroadcastITCase. 
--- .../flink/streaming/api/DataStreamTest.java | 124 +++--------- .../runtime/BroadcastStateITCase.java | 183 ++++++++++++++++++ 2 files changed, 207 insertions(+), 100 deletions(-) create mode 100644 flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java index ca76ef4c2923a..ec8a134e8248b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java @@ -38,6 +38,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.collector.selector.OutputSelector; +import org.apache.flink.streaming.api.datastream.BroadcastConnectedStream; import org.apache.flink.streaming.api.datastream.BroadcastStream; import org.apache.flink.streaming.api.datastream.ConnectedStreams; import org.apache.flink.streaming.api.datastream.DataStream; @@ -87,9 +88,7 @@ import javax.annotation.Nullable; import java.lang.reflect.Method; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -103,6 +102,9 @@ @SuppressWarnings("serial") public class DataStreamTest extends TestLogger { + @Rule + public ExpectedException expectedException = ExpectedException.none(); + /** * Tests union functionality. This ensures that self-unions and unions of streams * with differing parallelism work. 
@@ -763,99 +765,10 @@ public void onTimer( assertTrue(getOperatorForDataStream(processed) instanceof ProcessOperator); } - @Test - public void testConnectWithBroadcastTranslation() throws Exception { - - final Map expected = new HashMap<>(); - expected.put(0L, "test:0"); - expected.put(1L, "test:1"); - expected.put(2L, "test:2"); - expected.put(3L, "test:3"); - expected.put(4L, "test:4"); - expected.put(5L, "test:5"); - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); - - final DataStream srcOne = env.generateSequence(0L, 5L) - .assignTimestampsAndWatermarks(new CustomWmEmitter() { - - @Override - public long extractTimestamp(Long element, long previousElementTimestamp) { - return element; - } - }).keyBy((KeySelector) value -> value); - - final DataStream srcTwo = env.fromCollection(expected.values()) - .assignTimestampsAndWatermarks(new CustomWmEmitter() { - @Override - public long extractTimestamp(String element, long previousElementTimestamp) { - return Long.parseLong(element.split(":")[1]); - } - }); - - final BroadcastStream broadcast = srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR); - - // the timestamp should be high enough to trigger the timer after all the elements arrive. 
- final DataStream output = srcOne.connect(broadcast).process( - new TestBroadcastProcessFunction(100000L, expected)); - - output.addSink(new DiscardingSink<>()); - env.execute(); - } - - private abstract static class CustomWmEmitter implements AssignerWithPunctuatedWatermarks { - - @Nullable - @Override - public Watermark checkAndGetNextWatermark(T lastElement, long extractedTimestamp) { - return new Watermark(extractedTimestamp); - } - } - - private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction { - - private final Map expectedState; - - private final long timerTimestamp; - - static final MapStateDescriptor DESCRIPTOR = new MapStateDescriptor<>( - "broadcast-state", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO - ); - - TestBroadcastProcessFunction( - final long timerTS, - final Map expectedBroadcastState - ) { - expectedState = expectedBroadcastState; - timerTimestamp = timerTS; - } - - @Override - public void processElement(Long value, KeyedReadOnlyContext ctx, Collector out) throws Exception { - ctx.timerService().registerEventTimeTimer(timerTimestamp); - } - - @Override - public void processBroadcastElement(String value, KeyedContext ctx, Collector out) throws Exception { - long key = Long.parseLong(value.split(":")[1]); - ctx.getBroadcastState(DESCRIPTOR).put(key, value); - } - - @Override - public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { - Map map = new HashMap<>(); - for (Map.Entry entry : ctx.getBroadcastState(DESCRIPTOR).immutableEntries()) { - map.put(entry.getKey(), entry.getValue()); - } - Assert.assertEquals(expectedState, map); - } - } - /** * Tests that with a {@link KeyedStream} we have to provide a {@link KeyedBroadcastProcessFunction}. 
*/ - @Test(expected = IllegalArgumentException.class) + @Test public void testFailedTranslationOnKeyed() { final MapStateDescriptor descriptor = new MapStateDescriptor<>( @@ -881,8 +794,11 @@ public long extractTimestamp(String element, long previousElementTimestamp) { }); BroadcastStream broadcast = srcTwo.broadcast(descriptor); - srcOne.connect(broadcast) - .process(new BroadcastProcessFunction() { + BroadcastConnectedStream bcStream = srcOne.connect(broadcast); + + expectedException.expect(IllegalArgumentException.class); + bcStream.process( + new BroadcastProcessFunction() { @Override public void processBroadcastElement(String value, Context ctx, Collector out) throws Exception { // do nothing @@ -898,7 +814,7 @@ public void processElement(Long value, ReadOnlyContext ctx, Collector ou /** * Tests that with a non-keyed stream we have to provide a {@link BroadcastProcessFunction}. */ - @Test(expected = IllegalArgumentException.class) + @Test public void testFailedTranslationOnNonKeyed() { final MapStateDescriptor descriptor = new MapStateDescriptor<>( @@ -924,9 +840,11 @@ public long extractTimestamp(String element, long previousElementTimestamp) { }); BroadcastStream broadcast = srcTwo.broadcast(descriptor); - srcOne.connect(broadcast) - .process(new KeyedBroadcastProcessFunction() { + BroadcastConnectedStream bcStream = srcOne.connect(broadcast); + expectedException.expect(IllegalArgumentException.class); + bcStream.process( + new KeyedBroadcastProcessFunction() { @Override public void processBroadcastElement(String value, KeyedContext ctx, Collector out) throws Exception { // do nothing @@ -939,6 +857,15 @@ public void processElement(Long value, KeyedReadOnlyContext ctx, Collector implements AssignerWithPunctuatedWatermarks { + + @Nullable + @Override + public Watermark checkAndGetNextWatermark(T lastElement, long extractedTimestamp) { + return new Watermark(extractedTimestamp); + } + } + @Test public void operatorTest() { StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment(); @@ -1131,9 +1058,6 @@ public void testChannelSelectors() { // KeyBy testing ///////////////////////////////////////////////////////////// - @Rule - public ExpectedException expectedException = ExpectedException.none(); - @Test public void testPrimitiveArrayKeyRejection() { diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java new file mode 100644 index 0000000000000..4b0b9c5e7c7db --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/BroadcastStateITCase.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.test.streaming.runtime; + +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.TimeCharacteristic; +import org.apache.flink.streaming.api.datastream.BroadcastStream; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.util.Collector; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; + +/** + * ITCase for the {@link org.apache.flink.api.common.state.BroadcastState}. 
+ */ +public class BroadcastStateITCase { + + @Test + public void testConnectWithBroadcastTranslation() throws Exception { + + final Map expected = new HashMap<>(); + expected.put(0L, "test:0"); + expected.put(1L, "test:1"); + expected.put(2L, "test:2"); + expected.put(3L, "test:3"); + expected.put(4L, "test:4"); + expected.put(5L, "test:5"); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); + + final DataStream srcOne = env.generateSequence(0L, 5L) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + + private static final long serialVersionUID = -8500904795760316195L; + + @Override + public long extractTimestamp(Long element, long previousElementTimestamp) { + return element; + } + }).keyBy((KeySelector) value -> value); + + final DataStream srcTwo = env.fromCollection(expected.values()) + .assignTimestampsAndWatermarks(new CustomWmEmitter() { + + private static final long serialVersionUID = -2148318224248467213L; + + @Override + public long extractTimestamp(String element, long previousElementTimestamp) { + return Long.parseLong(element.split(":")[1]); + } + }); + + final BroadcastStream broadcast = srcTwo.broadcast(TestBroadcastProcessFunction.DESCRIPTOR); + + // the timestamp should be high enough to trigger the timer after all the elements arrive. 
+ final DataStream output = srcOne.connect(broadcast).process( + new TestBroadcastProcessFunction(100000L, expected)); + + output + .addSink(new TestSink(expected.size())) + .setParallelism(1); + env.execute(); + } + + private static class TestSink extends RichSinkFunction { + + private static final long serialVersionUID = 7252508825104554749L; + + private final int expectedOutputCounter; + + private int outputCounter; + + TestSink(int expectedOutputCounter) { + this.expectedOutputCounter = expectedOutputCounter; + this.outputCounter = 0; + } + + @Override + public void invoke(String value, Context context) throws Exception { + outputCounter++; + } + + @Override + public void close() throws Exception { + super.close(); + + // make sure that all the timers fired + Assert.assertEquals(expectedOutputCounter, outputCounter); + } + } + + private abstract static class CustomWmEmitter implements AssignerWithPunctuatedWatermarks { + + private static final long serialVersionUID = -5187335197674841233L; + + @Nullable + @Override + public Watermark checkAndGetNextWatermark(T lastElement, long extractedTimestamp) { + return new Watermark(extractedTimestamp); + } + } + + /** + * A {@link KeyedBroadcastProcessFunction} which on the broadcast side puts elements in the broadcast state + * while on the non-broadcast side, it sets a timer to fire at some point in the future. Finally, when the onTimer + * method is called (i.e. when the timer fires), we verify that the result is the expected one. 
+ */ + private static class TestBroadcastProcessFunction extends KeyedBroadcastProcessFunction { + + private static final long serialVersionUID = 7616910653561100842L; + + private final Map expectedState; + + private final long timerTimestamp; + + static final MapStateDescriptor DESCRIPTOR = new MapStateDescriptor<>( + "broadcast-state", BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO + ); + + TestBroadcastProcessFunction( + final long timerTS, + final Map expectedBroadcastState + ) { + expectedState = expectedBroadcastState; + timerTimestamp = timerTS; + } + + @Override + public void processElement(Long value, KeyedReadOnlyContext ctx, Collector out) throws Exception { + ctx.timerService().registerEventTimeTimer(timerTimestamp); + } + + @Override + public void processBroadcastElement(String value, KeyedContext ctx, Collector out) throws Exception { + long key = Long.parseLong(value.split(":")[1]); + ctx.getBroadcastState(DESCRIPTOR).put(key, value); + } + + @Override + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception { + Assert.assertEquals(timerTimestamp, timestamp); + + Map map = new HashMap<>(); + for (Map.Entry entry : ctx.getBroadcastState(DESCRIPTOR).immutableEntries()) { + map.put(entry.getKey(), entry.getValue()); + } + + Assert.assertEquals(expectedState, map); + + out.collect(Long.toString(timestamp)); + } + } +}