From eb6f5bc66f22ad664d9f433624c8a4ca8173b5d1 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 7 Sep 2017 19:36:17 -0700 Subject: [PATCH] Key StateTable off id, not full StateTag --- .../apache/beam/runners/core/StateTable.java | 40 +++++++++++-- .../apache/beam/runners/core/StateTags.java | 13 ++++ .../core/TestInMemoryStateInternals.java | 6 +- .../beam/runners/core/ReduceFnTester.java | 22 +++++-- .../beam/runners/core/StateInternalsTest.java | 59 +++++++++++++++++++ .../CopyOnAccessInMemoryStateInternals.java | 10 +++- 6 files changed, 135 insertions(+), 15 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTable.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTable.java index d996729a476a..fa858b0df87d 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTable.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTable.java @@ -17,20 +17,27 @@ */ package org.apache.beam.runners.core; +import com.google.common.base.Equivalence; import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; +import java.util.HashMap; import java.util.Map; import java.util.Set; +import javax.annotation.Nullable; import org.apache.beam.runners.core.StateTag.StateBinder; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; /** * Table mapping {@code StateNamespace} and {@code StateTag} to a {@code State} instance. + * + *

Two {@link StateTag StateTags} with the same ID are considered equivalent. The remaining + * information carried by the {@link StateTag} is used to configure an empty state cell if it is not + * yet initialized. */ public abstract class StateTable { - private final Table, State> stateTable = + private final Table, State> stateTable = HashBasedTable.create(); /** @@ -40,7 +47,10 @@ public abstract class StateTable { */ public StateT get( StateNamespace namespace, StateTag tag, StateContext c) { - State storage = stateTable.get(namespace, tag); + + Equivalence.Wrapper tagById = StateTags.ID_EQUIVALENCE.wrap((StateTag) tag); + + @Nullable State storage = getOrNull(namespace, tagById, c); if (storage != null) { @SuppressWarnings("unchecked") StateT typedStorage = (StateT) storage; @@ -48,10 +58,20 @@ public StateT get( } StateT typedStorage = tag.bind(binderForNamespace(namespace, c)); - stateTable.put(namespace, tag, typedStorage); + stateTable.put(namespace, tagById, typedStorage); return typedStorage; } + /** + * Gets the {@link State} in the specified {@link StateNamespace} with the specified identifier or + * {@code null} if it is not yet present. + */ + @Nullable + public State getOrNull( + StateNamespace namespace, Equivalence.Wrapper tag, StateContext c) { + return stateTable.get(namespace, tag); + } + public void clearNamespace(StateNamespace namespace) { stateTable.rowKeySet().remove(namespace); } @@ -68,8 +88,18 @@ public boolean isNamespaceInUse(StateNamespace namespace) { return stateTable.containsRow(namespace); } - public Map, State> getTagsInUse(StateNamespace namespace) { - return stateTable.row(namespace); + public Map getTagsInUse(StateNamespace namespace) { + // Because of shading, Equivalence.Wrapper cannot be on the API surface; it won't work. + // If runners-core ceases to shade Guava then it can (all runners should shade runners-core + // anyhow) + Map, State> row = stateTable.row(namespace); + HashMap result = new HashMap<>(); + + for (Map.Entry, State> entry : row.entrySet()) { + result.put(entry.getKey().get(), entry.getValue()); + } + + return result; } public Set getNamespacesInUse() { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index a98f47d02e02..da94ef289715 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.core; +import com.google.common.base.Equivalence; import com.google.common.base.MoreObjects; import java.io.IOException; import java.io.Serializable; @@ -48,6 +49,18 @@ public class StateTags { private static final CoderRegistry STANDARD_REGISTRY = CoderRegistry.createDefault(); + public static final Equivalence ID_EQUIVALENCE = new Equivalence() { + @Override + protected boolean doEquivalent(StateTag a, StateTag b) { + return a.getId().equals(b.getId()); + } + + @Override + protected int doHash(StateTag stateTag) { + return stateTag.getId().hashCode(); + } + }; + /** @deprecated for migration purposes only */ @Deprecated private static StateBinder adaptTagBinder(final StateTag.StateBinder binder) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/TestInMemoryStateInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/TestInMemoryStateInternals.java index 2052c039f80a..ee8d56065086 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/TestInMemoryStateInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/TestInMemoryStateInternals.java @@ -32,9 +32,9 @@ public TestInMemoryStateInternals(K key) { super(key); } - public Set> getTagsInUse(StateNamespace namespace) { - Set> inUse = new HashSet<>(); - for (Map.Entry, State> entry : + public Set getTagsInUse(StateNamespace namespace) { + Set inUse = new HashSet<>(); + for (Map.Entry entry : inMemoryState.getTagsInUse(namespace).entrySet()) { if (!isEmptyForTesting(entry.getValue())) { inUse.add(entry.getKey()); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java index 7ca96b9b549d..6f7a4f430016 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Equivalence; import com.google.common.base.Function; import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; @@ -365,28 +366,41 @@ public final void assertHasOnlyGlobalAndPaneInfoFor(W... expectedWindows) { private void assertHasOnlyGlobalAndAllowedTags( Set expectedWindows, Set> allowedTags) { Set expectedWindowsSet = new HashSet<>(); + + Set> allowedEquivalentTags = new HashSet<>(); + for (StateTag tag : allowedTags) { + allowedEquivalentTags.add(StateTags.ID_EQUIVALENCE.wrap(tag)); + } + for (W expectedWindow : expectedWindows) { expectedWindowsSet.add(windowNamespace(expectedWindow)); } - Map>> actualWindows = new HashMap<>(); + Map>> actualWindows = new HashMap<>(); for (StateNamespace namespace : stateInternals.getNamespacesInUse()) { if (namespace instanceof StateNamespaces.GlobalNamespace) { continue; } else if (namespace instanceof StateNamespaces.WindowNamespace) { - Set> tagsInUse = stateInternals.getTagsInUse(namespace); + Set> tagsInUse = new HashSet<>(); + for (StateTag tag : stateInternals.getTagsInUse(namespace)) { + tagsInUse.add(StateTags.ID_EQUIVALENCE.wrap(tag)); + } if (tagsInUse.isEmpty()) { continue; } actualWindows.put(namespace, tagsInUse); - Set> unexpected = Sets.difference(tagsInUse, allowedTags); + Set> unexpected = + Sets.difference(tagsInUse, allowedEquivalentTags); if (unexpected.isEmpty()) { continue; } else { fail(namespace + " has unexpected states: " + tagsInUse); } } else if (namespace instanceof StateNamespaces.WindowAndTriggerNamespace) { - Set> tagsInUse = stateInternals.getTagsInUse(namespace); + Set> tagsInUse = new HashSet<>(); + for (StateTag tag : stateInternals.getTagsInUse(namespace)) { + tagsInUse.add(StateTags.ID_EQUIVALENCE.wrap(tag)); + } assertTrue(namespace + " contains " + tagsInUse, tagsInUse.isEmpty()); } else { fail("Unrecognized namespace " + namespace); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index ae07fe6b1ced..eb438bac3cea 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -28,9 +28,15 @@ import static org.junit.Assert.assertTrue; import com.google.common.collect.Iterables; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Objects; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.state.BagState; @@ -78,6 +84,13 @@ public abstract class StateInternalsTest { private static final StateTag WATERMARK_EOW_ADDR = StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); + // Two distinct tags because they have non-equals() coders + private static final StateTag> STRING_BAG_ADDR1 = + StateTags.bag("badStringBag", new StringCoderWithIdentityEquality()); + + private static final StateTag> STRING_BAG_ADDR2 = + StateTags.bag("badStringBag", new StringCoderWithIdentityEquality()); + private StateInternals underTest; @Before @@ -610,4 +623,50 @@ public void testMapReadable() throws Exception { assertThat(value.get("C").read(), equalTo(3)); } + @Test + public void testBagWithBadCoderEquality() throws Exception { + // Ensure two instances of the bad coder are distinct; models user who fails to + // override equals() or inherit from CustomCoder for StructuredCoder + assertThat( + new StringCoderWithIdentityEquality(), not(equalTo(new StringCoderWithIdentityEquality()))); + + BagState state1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR1); + state1.add("hello"); + + BagState state2 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR2); + assertThat(state2.read(), containsInAnyOrder("hello")); + } + + private static class StringCoderWithIdentityEquality extends Coder { + + private final StringUtf8Coder realCoder = StringUtf8Coder.of(); + + @Override + public void encode(String value, OutputStream outStream) throws CoderException, IOException { + realCoder.encode(value, outStream); + } + + @Override + public String decode(InputStream inStream) throws CoderException, IOException { + return realCoder.decode(inStream); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException {} + + @Override + public boolean equals(Object other) { + return other == this; + } + + @Override + public int hashCode() { + return super.hashCode(); + } + } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java index 3c701c77695a..848bf712da80 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java @@ -264,8 +264,12 @@ public CopyOnBindBinderFactory(Optional underlying) { } private boolean containedInUnderlying(StateNamespace namespace, StateTag tag) { - return underlying.isPresent() && underlying.get().isNamespaceInUse(namespace) - && underlying.get().getTagsInUse(namespace).containsKey(tag); + return underlying.isPresent() + && underlying.get().isNamespaceInUse(namespace) + && underlying + .get() + .getTagsInUse(namespace) + .containsKey(tag); } @Override @@ -388,7 +392,7 @@ public ReadThroughBinderFactory(StateTable underlying) { public Instant readThroughAndGetEarliestHold(StateTable readTo) { Instant earliestHold = BoundedWindow.TIMESTAMP_MAX_VALUE; for (StateNamespace namespace : underlying.getNamespacesInUse()) { - for (Map.Entry, ? extends State> existingState : + for (Map.Entry existingState : underlying.getTagsInUse(namespace).entrySet()) { if (!((InMemoryState) existingState.getValue()).isCleared()) { // Only read through non-cleared values to ensure that completed windows are