From 54983550c1538a0e53290dccede0b8a3cdc914d9 Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Wed, 7 Jun 2017 14:34:25 +0800 Subject: [PATCH 1/3] Use CoderTypeSerializer and remove unuse code in FlinkStateInternals --- .../streaming/state/FlinkStateInternals.java | 198 +----------------- 1 file changed, 10 insertions(+), 188 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index f0d327819191..d8771de998f2 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -196,9 +195,8 @@ private static class FlinkValueState implements ValueState { this.address = address; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(coder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(coder)); } @Override @@ -282,9 +280,8 @@ private static class FlinkBagState implements BagState { this.address = address; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(coder); - - flinkStateDescriptor = new ListStateDescriptor<>(address.getId(), typeInfo); + flinkStateDescriptor = new ListStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(coder)); } @Override @@ -398,9 +395,8 @@ private static class FlinkCombiningState this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(accumCoder)); } @Override @@ -545,179 +541,6 @@ public int hashCode() { } } - private static class FlinkKeyedCombiningState - implements CombiningState { - - private final StateNamespace namespace; - private final StateTag> address; - private final Combine.CombineFn combineFn; - private final ValueStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; - private final FlinkStateInternals flinkStateInternals; - - FlinkKeyedCombiningState( - KeyedStateBackend flinkStateBackend, - StateTag> address, - Combine.CombineFn combineFn, - StateNamespace namespace, - Coder accumCoder, - FlinkStateInternals flinkStateInternals) { - - this.namespace = namespace; - this.address = address; - this.combineFn = combineFn; - this.flinkStateBackend = flinkStateBackend; - this.flinkStateInternals = flinkStateInternals; - - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); - } - - @Override - public CombiningState readLater() { - return this; - } - - @Override - public void add(InputT value) { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT current = state.value(); - if (current == null) { - current = combineFn.createAccumulator(); - } - current = combineFn.addInput(current, value); - state.update(current); - } catch (RuntimeException re) { - throw re; - } catch (Exception e) { - throw new RuntimeException("Error adding to state." , e); - } - } - - @Override - public void addAccum(AccumT accum) { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT current = state.value(); - if (current == null) { - state.update(accum); - } else { - current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); - state.update(current); - } - } catch (Exception e) { - throw new RuntimeException("Error adding to state.", e); - } - } - - @Override - public AccumT getAccum() { - try { - return flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).value(); - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - } - - @Override - public AccumT mergeAccumulators(Iterable accumulators) { - return combineFn.mergeAccumulators(accumulators); - } - - @Override - public OutputT read() { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT accum = state.value(); - if (accum != null) { - return combineFn.extractOutput(accum); - } else { - return combineFn.extractOutput(combineFn.createAccumulator()); - } - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - } - - @Override - public ReadableState isEmpty() { - return new ReadableState() { - @Override - public Boolean read() { - try { - return flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).value() == null; - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - - } - - @Override - public ReadableState readLater() { - return this; - } - }; - } - - @Override - public void clear() { - try { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).clear(); - } catch (Exception e) { - throw new RuntimeException("Error clearing state.", e); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - FlinkKeyedCombiningState that = - (FlinkKeyedCombiningState) o; - - return namespace.equals(that.namespace) && address.equals(that.address); - - } - - @Override - public int hashCode() { - int result = namespace.hashCode(); - result = 31 * result + address.hashCode(); - return result; - } - } - private static class FlinkCombiningStateWithContext implements CombiningState { @@ -745,9 +568,8 @@ private static class FlinkCombiningStateWithContext this.flinkStateInternals = flinkStateInternals; this.context = context; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(accumCoder)); } @Override @@ -913,8 +735,8 @@ public FlinkWatermarkHoldState( this.flinkStateBackend = flinkStateBackend; this.flinkStateInternals = flinkStateInternals; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(InstantCoder.of()); - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(InstantCoder.of())); } @Override From 95d7482f3fd84c035ea1ddf1b6eae5199c04ecff Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Wed, 7 Jun 2017 14:40:30 +0800 Subject: [PATCH 2/3] [BEAM-1483] Support SetState in Flink runner and fix MapState to be consistent with InMemoryStateInternals. --- runners/flink/pom.xml | 1 - .../streaming/state/FlinkStateInternals.java | 227 ++++++++++++++---- .../streaming/FlinkStateInternalsTest.java | 17 -- 3 files changed, 182 insertions(+), 63 deletions(-) diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index a5b8203507ff..339aa8e445a9 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -91,7 +91,6 @@ org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders, org.apache.beam.sdk.testing.LargeKeys$Above100MB, - org.apache.beam.sdk.testing.UsesSetState, org.apache.beam.sdk.testing.UsesCommittedMetrics, org.apache.beam.sdk.testing.UsesTestStream, org.apache.beam.sdk.testing.UsesSplittableParDo diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index d8771de998f2..a0b015b57d32 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import java.nio.ByteBuffer; import java.util.Collections; @@ -33,6 +34,7 @@ import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; @@ -48,6 +50,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.joda.time.Instant; @@ -127,8 +130,8 @@ public BagState bindBag( @Override public SetState bindSet( StateTag> address, Coder elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); + return new FlinkSetState<>( + flinkStateBackend, address, namespace, elemCoder); } @Override @@ -875,24 +878,15 @@ private static class FlinkMapState implements MapState get(final KeyT input) { - return new ReadableState() { - @Override - public ValueT read() { - try { - return flinkStateBackend.getPartitionedState( + try { + return ReadableStates.immediate( + flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, - flinkStateDescriptor).get(input); - } catch (Exception e) { - throw new RuntimeException("Error get from state.", e); - } - } - - @Override - public ReadableState readLater() { - return this; - } - }; + flinkStateDescriptor).get(input)); + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } } @Override @@ -909,32 +903,22 @@ public void put(KeyT key, ValueT value) { @Override public ReadableState putIfAbsent(final KeyT key, final ValueT value) { - return new ReadableState() { - @Override - public ValueT read() { - try { - ValueT current = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).get(key); - - if (current == null) { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).put(key, value); - } - return current; - } catch (Exception e) { - throw new RuntimeException("Error put kv to state.", e); - } - } + try { + ValueT current = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(key); - @Override - public ReadableState readLater() { - return this; + if (current == null) { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(key, value); } - }; + return ReadableStates.immediate(current); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } } @Override @@ -955,10 +939,11 @@ public ReadableState> keys() { @Override public Iterable read() { try { - return flinkStateBackend.getPartitionedState( + Iterable result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).keys(); + return result != null ? result : Collections.emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state keys.", e); } @@ -977,10 +962,11 @@ public ReadableState> values() { @Override public Iterable read() { try { - return flinkStateBackend.getPartitionedState( + Iterable result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).values(); + return result != null ? result : Collections.emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state values.", e); } @@ -999,10 +985,11 @@ public ReadableState>> entries() { @Override public Iterable> read() { try { - return flinkStateBackend.getPartitionedState( + Iterable> result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).entries(); + return result != null ? result : Collections.>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state entries.", e); } @@ -1050,4 +1037,154 @@ public int hashCode() { } } + private static class FlinkSetState implements SetState { + + private final StateNamespace namespace; + private final StateTag> address; + private final MapStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + + FlinkSetState( + KeyedStateBackend flinkStateBackend, + StateTag> address, + StateNamespace namespace, + Coder coder) { + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(), + new CoderTypeSerializer<>(coder), new BooleanSerializer()); + } + + @Override + public ReadableState contains(final T t) { + try { + Boolean result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(t); + return ReadableStates.immediate(result != null ? result : false); + } catch (Exception e) { + throw new RuntimeException("Error contains value from state.", e); + } + } + + @Override + public ReadableState addIfAbsent(final T t) { + try { + org.apache.flink.api.common.state.MapState state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + boolean alreadyContained = state.contains(t); + if (!alreadyContained) { + state.put(t, true); + } + return ReadableStates.immediate(!alreadyContained); + } catch (Exception e) { + throw new RuntimeException("Error addIfAbsent value to state.", e); + } + } + + @Override + public void remove(T t) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).remove(t); + } catch (Exception e) { + throw new RuntimeException("Error remove value to state.", e); + } + } + + @Override + public SetState readLater() { + return this; + } + + @Override + public void add(T value) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(value, true); + } catch (Exception e) { + throw new RuntimeException("Error add value to state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + Iterable result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result == null || Iterables.isEmpty(result); + } catch (Exception e) { + throw new RuntimeException("Error isEmpty from state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public Iterable read() { + try { + Iterable result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result != null ? result : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error read from state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkSetState that = (FlinkSetState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index e7564ec914a2..b8d41de77b44 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -63,21 +63,4 @@ protected StateInternals createStateInternals() { } } - ///////////////////////// Unsupported tests \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - @Override - public void testSet() {} - - @Override - public void testSetIsEmpty() {} - - @Override - public void testMergeSetIntoSource() {} - - @Override - public void testMergeSetIntoNewNamespace() {} - - @Override - public void testMap() {} - } From a9bcd0f80fc1c5c60bf1e55558c6d508dbccd761 Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Tue, 13 Jun 2017 10:15:33 +0800 Subject: [PATCH 3/3] Add set and map readable test to StateInternalsTest --- .../beam/runners/core/StateInternalsTest.java | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index bf3156aad110..6011fb48aed6 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; import java.util.Arrays; import java.util.Map; import java.util.Objects; @@ -570,4 +571,43 @@ public void testMergeLatestWatermarkIntoSource() throws Exception { assertThat(value1.read(), equalTo(null)); assertThat(value2.read(), equalTo(null)); } + + @Test + public void testSetReadable() throws Exception { + SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + // test contains + ReadableState readable = value.contains("A"); + value.add("A"); + assertFalse(readable.read()); + + // test addIfAbsent + value.addIfAbsent("B"); + assertTrue(value.contains("B").read()); + } + + @Test + public void testMapReadable() throws Exception { + MapState value = underTest.state(NAMESPACE_1, STRING_MAP_ADDR); + + // test iterable, should just return a iterable view of the values contained in this map. + // The iterable is backed by the map, so changes to the map are reflected in the iterable. + ReadableState> keys = value.keys(); + ReadableState> values = value.values(); + ReadableState>> entries = value.entries(); + value.put("A", 1); + assertFalse(Iterables.isEmpty(keys.read())); + assertFalse(Iterables.isEmpty(values.read())); + assertFalse(Iterables.isEmpty(entries.read())); + + // test get + ReadableState get = value.get("B"); + value.put("B", 2); + assertNull(get.read()); + + // test addIfAbsent + value.putIfAbsent("C", 3); + assertThat(value.get("C").read(), equalTo(3)); + } + }