From 1edb7f160444fad08fc1a24689344d2ea76faad6 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Tue, 29 Aug 2017 10:45:04 -0700 Subject: [PATCH 1/2] [BEAM-1347] Create value state, combining state, and bag state views over the BagUserState. Also bind the state persistence to the end of finishBundle. --- .../beam/fn/harness/FnApiDoFnRunner.java | 380 +++++++++++++++++- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 229 +++++++++++ .../harness/state/FakeBeamFnStateClient.java | 2 +- 3 files changed, 605 insertions(+), 6 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index d325bb29d318..92ba1f3e3f8a 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -18,45 +18,77 @@ package org.apache.beam.fn.harness; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import com.google.auto.service.AutoService; +import com.google.common.base.Suppliers; import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; +import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import 
org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.state.BagUserState; import org.apache.beam.fn.harness.state.BeamFnStateClient; +import org.apache.beam.fn.v1.BeamFnApi.StateKey; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest.Builder; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; +import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.StateBinder; +import org.apache.beam.sdk.state.StateContext; +import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; import org.apache.beam.sdk.transforms.DoFn.ProcessContext; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; @@ -141,7 +173,13 @@ public DoFnRunner createRunnerForPTransform( @SuppressWarnings({"unchecked", "rawtypes"}) DoFnRunner runner = new FnApiDoFnRunner<>( pipelineOptions, + beamFnStateClient, + pTransformId, + processBundleInstructionId, doFnInfo.getDoFn(), + WindowedValue.getFullCoder( + doFnInfo.getInputCoder(), + doFnInfo.getWindowingStrategy().getWindowFn().windowCoder()), (Collection>>) (Collection) tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), tagToOutputMap, @@ -162,42 +200,68 @@ public DoFnRunner createRunnerForPTransform( ////////////////////////////////////////////////////////////////////////////////////////////////// private final PipelineOptions pipelineOptions; + private final BeamFnStateClient beamFnStateClient; + private final String ptransformId; + private final Supplier processBundleInstructionId; private final DoFn doFn; + private final WindowedValueCoder inputCoder; private final Collection>> mainOutputConsumers; private final Multimap, ThrowingConsumer>> outputMap; + private final WindowingStrategy windowingStrategy; + private final DoFnSignature doFnSignature; private final DoFnInvoker doFnInvoker; + private final StateBinder stateBinder; private final StartBundleContext startBundleContext; private final 
ProcessBundleContext processBundleContext; private final FinishBundleContext finishBundleContext; - private final WindowingStrategy windowingStrategy; - private final DoFnSignature doFnSignature; + private final Collection stateFinalizers; /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + * The lifetime of this member is only valid during {@link #processElement} + * and is null otherwise. */ private WindowedValue currentElement; /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + * The lifetime of this member is only valid during {@link #processElement} + * and is null otherwise. */ private BoundedWindow currentWindow; + /** + * This member should only be accessed indirectly by calling + * {@link #createOrUseCachedBagUserStateKey} and is only valid during {@link #processElement} + * and is null otherwise. + */ + private StateKey.BagUserState cachedPartialBagUserStateKey; + + FnApiDoFnRunner( PipelineOptions pipelineOptions, + BeamFnStateClient beamFnStateClient, + String ptransformId, + Supplier processBundleInstructionId, DoFn doFn, + WindowedValueCoder inputCoder, Collection>> mainOutputConsumers, Multimap, ThrowingConsumer>> outputMap, WindowingStrategy windowingStrategy) { this.pipelineOptions = pipelineOptions; + this.beamFnStateClient = beamFnStateClient; + this.ptransformId = ptransformId; + this.processBundleInstructionId = processBundleInstructionId; this.doFn = doFn; + this.inputCoder = inputCoder; this.mainOutputConsumers = mainOutputConsumers; this.outputMap = outputMap; this.windowingStrategy = windowingStrategy; this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn); this.doFnInvoker = DoFnInvokers.invokerFor(doFn); + this.stateBinder = new BeamFnStateBinder(); this.startBundleContext = new StartBundleContext(); this.processBundleContext = new ProcessBundleContext(); this.finishBundleContext = new FinishBundleContext(); + this.stateFinalizers = new 
ArrayList<>(); } @Override @@ -218,6 +282,7 @@ public void processElement(WindowedValue elem) { } finally { currentElement = null; currentWindow = null; + cachedPartialBagUserStateKey = null; } } @@ -233,6 +298,18 @@ public void onTimer( @Override public void finishBundle() { doFnInvoker.invokeFinishBundle(finishBundleContext); + + // Persist all dirty state cells + try { + for (ThrowingRunnable runnable : stateFinalizers) { + runnable.run(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(e); + } catch (Exception e) { + throw new IllegalStateException(e); + } } /** @@ -367,7 +444,15 @@ public RestrictionTracker restrictionTracker() { @Override public State state(String stateId) { - throw new UnsupportedOperationException("TODO: Add support for state"); + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); + checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); + StateSpec spec; + try { + spec = (StateSpec) stateDeclaration.field().get(doFn); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + return spec.bind(stateId, stateBinder); } @Override @@ -545,4 +630,289 @@ public void output(TupleTag tag, T output, Instant timestamp, BoundedWind WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } } + + /** + * A {@link StateBinder} that uses the Beam Fn State API to read and write user state. + * + *
<p>
TODO: Add support for {@link #bindMap} and {@link #bindSet}. Note that + * {@link #bindWatermark} should never be implemented. + */ + private class BeamFnStateBinder implements StateBinder { + private final Map stateObjectCache = new HashMap<>(); + + @Override + public ValueState bindValue(String id, StateSpec> spec, Coder coder) { + return (ValueState) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey.BagUserState s) { + return new ValueState() { + private final BagUserState impl = createBagUserState(id, coder); + + @Override + public void clear() { + impl.clear(); + } + + @Override + public void write(T input) { + impl.clear(); + impl.append(input); + } + + @Override + public T read() { + Iterator value = impl.get().iterator(); + if (value.hasNext()) { + return value.next(); + } else { + return null; + } + } + + @Override + public ValueState readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + }); + } + + @Override + public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { + return (BagState) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey.BagUserState s) { + return new BagState() { + private final BagUserState impl = createBagUserState(id, elemCoder); + + @Override + public void add(T value) { + impl.append(value); + } + + @Override + public ReadableState isEmpty() { + return ReadableStates.immediate(!impl.get().iterator().hasNext()); + } + + @Override + public Iterable read() { + return impl.get(); + } + + @Override + public BagState readLater() { + // TODO: Support prefetching. 
+ return this; + } + + @Override + public void clear() { + impl.clear(); + } + }; + } + }); + } + + @Override + public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { + throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + } + + @Override + public MapState bindMap(String id, + StateSpec> spec, Coder mapKeyCoder, + Coder mapValueCoder) { + throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + } + + @Override + public CombiningState bindCombining( + String id, + StateSpec> spec, Coder accumCoder, + CombineFn combineFn) { + return (CombiningState) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey.BagUserState s) { + // TODO: Support squashing accumulators depending on whether we know of all + // remote accumulators and local accumulators or just local accumulators. + return new CombiningState() { + private final BagUserState impl = createBagUserState(id, accumCoder); + + @Override + public AccumT getAccum() { + Iterator iterator = impl.get().iterator(); + if (iterator.hasNext()) { + return iterator.next(); + } + return combineFn.createAccumulator(); + } + + @Override + public void addAccum(AccumT accum) { + Iterator iterator = impl.get().iterator(); + + // Only merge if there was a prior value + if (iterator.hasNext()) { + accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum)); + // Since there was a prior value, we need to clear. 
+ impl.clear(); + } + + impl.append(accum); + } + + @Override + public AccumT mergeAccumulators(Iterable accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public CombiningState readLater() { + return this; + } + + @Override + public OutputT read() { + Iterator iterator = impl.get().iterator(); + if (iterator.hasNext()) { + return combineFn.extractOutput(iterator.next()); + } + return combineFn.defaultValue(); + } + + @Override + public void add(InputT value) { + AccumT newAccumulator = combineFn.addInput(getAccum(), value); + impl.clear(); + impl.append(newAccumulator); + } + + @Override + public ReadableState isEmpty() { + return ReadableStates.immediate(!impl.get().iterator().hasNext()); + } + + @Override + public void clear() { + impl.clear(); + } + }; + } + }); + } + + @Override + public CombiningState + bindCombiningWithContext( + String id, + StateSpec> spec, + Coder accumCoder, + CombineFnWithContext combineFn) { + return (CombiningState) stateObjectCache.computeIfAbsent( + createOrUseCachedBagUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey.BagUserState s) { + return bindCombining(id, spec, accumCoder, CombineFnUtil.bindContext(combineFn, + new StateContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public T sideInput(PCollectionView view) { + return sideInput(view); + } + + @Override + public BoundedWindow window() { + return currentWindow; + } + })); + } + }); + } + + /** + * @deprecated The Fn API has no plans to implement WatermarkHoldState as of this writing + * and is waiting on resolution of BEAM-2535. 
+ */ + @Override + @Deprecated + public WatermarkHoldState bindWatermark(String id, StateSpec spec, + TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API."); + } + + private BagUserState createBagUserState(String id, Coder coder) { + BagUserState rval = new BagUserState( + beamFnStateClient, + id, + coder, + new Supplier() { + /** Memoizes the partial state key for the lifetime of the {@link BagUserState}. */ + private final Supplier memoizingSupplier = + Suppliers.memoize(() -> createOrUseCachedBagUserStateKey(id))::get; + + @Override + public Builder get() { + return StateRequest.newBuilder() + .setInstructionReference(processBundleInstructionId.get()) + .setStateKey(StateKey.newBuilder() + .setBagUserState(memoizingSupplier.get())); + } + }); + stateFinalizers.add(rval::asyncClose); + return rval; + } + } + + /** + * Memoizes a partially built {@link StateKey} saving on the encoding cost of the key and + * window across multiple state cells for the lifetime of {@link #processElement}. + * + *
<p>
This should only be called during {@link #processElement}. + */ + private StateKey.BagUserState createOrUseCachedBagUserStateKey(String id) { + if (cachedPartialBagUserStateKey == null) { + checkState(currentElement.getValue() instanceof KV, + "Accessing state in unkeyed context. Current element is not a KV: %s.", + currentElement); + checkState(inputCoder.getCoderArguments().get(0) instanceof KvCoder, + "Accessing state in unkeyed context. No keyed coder found."); + + ByteString.Output encodedKeyOut = ByteString.newOutput(); + + Coder keyCoder = ((KvCoder) inputCoder.getValueCoder()).getKeyCoder(); + try { + keyCoder.encode(((KV) currentElement.getValue()).getKey(), encodedKeyOut); + } catch (IOException e) { + throw new IllegalStateException(e); + } + + ByteString.Output encodedWindowOut = ByteString.newOutput(); + try { + windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, encodedWindowOut); + } catch (IOException e) { + throw new IllegalStateException(e); + } + + cachedPartialBagUserStateKey = StateKey.BagUserState.newBuilder() + .setPtransformId(ptransformId) + .setKey(encodedKeyOut.toByteString()) + .setWindow(encodedWindowOut.toByteString()).buildPartial(); + } + return cachedPartialBagUserStateKey.toBuilder().setUserStateId(id).build(); + } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index ebec608f7fb3..4aa8080d0acc 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -22,6 +22,8 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertEquals; import static 
org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -32,22 +34,36 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.protobuf.ByteString; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.ServiceLoader; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.state.FakeBeamFnStateClient; +import org.apache.beam.fn.v1.BeamFnApi.StateKey; import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; +import org.apache.beam.sdk.transforms.CombineWithContext.Context; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.hamcrest.collection.IsMapContaining; @@ -58,6 +74,9 @@ /** Tests for {@link FnApiDoFnRunner}. 
*/ @RunWith(JUnit4.class) public class FnApiDoFnRunnerTest { + + public static final String TEST_PTRANSFORM_ID = "pTransformId"; + private static class TestDoFn extends DoFn { private static final TupleTag mainOutput = new TupleTag<>("mainOutput"); private static final TupleTag additionalOutput = new TupleTag<>("output"); @@ -164,6 +183,216 @@ public void testCreatingAndProcessingDoFn() throws Exception { mainOutputValues.clear(); } + private static class ConcatCombineFn extends CombineFn { + @Override + public String createAccumulator() { + return ""; + } + + @Override + public String addInput(String accumulator, String input) { + return accumulator.concat(input); + } + + @Override + public String mergeAccumulators(Iterable accumulators) { + StringBuilder builder = new StringBuilder(); + for (String value : accumulators) { + builder.append(value); + } + return builder.toString(); + } + + @Override + public String extractOutput(String accumulator) { + return accumulator; + } + } + + private static class ConcatCombineFnWithContext + extends CombineFnWithContext { + @Override + public String createAccumulator(Context c) { + return ""; + } + + @Override + public String addInput(String accumulator, String input, Context c) { + return accumulator.concat(input); + } + + @Override + public String mergeAccumulators(Iterable accumulators, Context c) { + StringBuilder builder = new StringBuilder(); + for (String value : accumulators) { + builder.append(value); + } + return builder.toString(); + } + + @Override + public String extractOutput(String accumulator, Context c) { + return accumulator; + } + } + + private static class TestStatefulDoFn extends DoFn, String> { + private static final TupleTag mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag additionalOutput = new TupleTag<>("output"); + + @StateId("value") + private final StateSpec> valueStateSpec = + StateSpecs.value(StringUtf8Coder.of()); + @StateId("bag") + private final StateSpec> 
bagStateSpec = + StateSpecs.bag(StringUtf8Coder.of()); + @StateId("combine") + private final StateSpec> combiningStateSpec = + StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn()); + @StateId("combineWithContext") + private final StateSpec> combiningWithContextStateSpec = + StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFnWithContext()); + + @ProcessElement + public void processElement(ProcessContext context, + @StateId("value") ValueState valueState, + @StateId("bag") BagState bagState, + @StateId("combine") CombiningState combiningState, + @StateId("combineWithContext") + CombiningState combiningWithContextState) { + context.output("value:" + valueState.read()); + valueState.write(context.element().getValue()); + + context.output("bag:" + Iterables.toString(bagState.read())); + bagState.add(context.element().getValue()); + + context.output("combine:" + combiningState.read()); + combiningState.add(context.element().getValue()); + + context.output("combineWithContext:" + combiningWithContextState.read()); + combiningWithContextState.add(context.element().getValue()); + } + } + + @Test + public void testUsingUserState() throws Exception { + String mainOutputId = "101"; + + DoFnInfo doFnInfo = DoFnInfo.forFn( + new TestStatefulDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), + Long.parseLong(mainOutputId), + ImmutableMap.of(Long.parseLong(mainOutputId), new TupleTag("mainOutput"))); + RunnerApi.FunctionSpec functionSpec = + RunnerApi.FunctionSpec.newBuilder() + .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) + .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .build(); + + FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(ImmutableMap.of( + key("value", "X"), encode("X0"), + key("bag", "X"), encode("X0"), + key("combine", "X"), encode("X0"), + key("combineWithContext", "X"), encode("X0") + )); + + List> mainOutputValues = new ArrayList<>(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer>) mainOutputValues::add); + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + fakeClient, + TEST_PTRANSFORM_ID, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder("inputTarget", "mainOutputTarget")); + + // Ensure that bag user state that is initially empty or populated works. + // Ensure that the key order does not matter when we traverse over KV pairs. 
+ ThrowingConsumer> mainInput = + Iterables.getOnlyElement(consumers.get("inputTarget")); + mainInput.accept(valueInGlobalWindow(KV.of("X", "X1"))); + mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1"))); + mainInput.accept(valueInGlobalWindow(KV.of("X", "X2"))); + mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y2"))); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("value:X0"), + valueInGlobalWindow("bag:[X0]"), + valueInGlobalWindow("combine:X0"), + valueInGlobalWindow("combineWithContext:X0"), + valueInGlobalWindow("value:null"), + valueInGlobalWindow("bag:[]"), + valueInGlobalWindow("combine:"), + valueInGlobalWindow("combineWithContext:"), + valueInGlobalWindow("value:X1"), + valueInGlobalWindow("bag:[X0, X1]"), + valueInGlobalWindow("combine:X0X1"), + valueInGlobalWindow("combineWithContext:X0X1"), + valueInGlobalWindow("value:Y1"), + valueInGlobalWindow("bag:[Y1]"), + valueInGlobalWindow("combine:Y1"), + valueInGlobalWindow("combineWithContext:Y1"))); + mainOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat(mainOutputValues, empty()); + + assertEquals( + ImmutableMap.builder() + .put(key("value", "X"), encode("X2")) + .put(key("bag", "X"), encode("X0", "X1", "X2")) + .put(key("combine", "X"), encode("X0X1X2")) + .put(key("combineWithContext", "X"), encode("X0X1X2")) + .put(key("value", "Y"), encode("Y2")) + .put(key("bag", "Y"), encode("Y1", "Y2")) + .put(key("combine", "Y"), encode("Y1Y2")) + .put(key("combineWithContext", "Y"), encode("Y1Y2")) + .build(), + fakeClient.getData()); + mainOutputValues.clear(); + } + + /** Produces a {@link StateKey} for the test PTransform id in the Global Window. 
*/ + private StateKey key(String userStateId, String key) throws IOException { + return StateKey.newBuilder().setBagUserState( + StateKey.BagUserState.newBuilder() + .setPtransformId(TEST_PTRANSFORM_ID) + .setUserStateId(userStateId) + .setKey(encode(key)) + .setWindow(ByteString.copyFrom( + CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE)))) + .build(); + } + + private ByteString encode(String ... values) throws IOException { + ByteString.Output out = ByteString.newOutput(); + for (String value : values) { + StringUtf8Coder.of().encode(value, out); + } + return out.toByteString(); + } + @Test public void testRegistration() { for (Registrar registrar : diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java index d26020743512..60080e13c7fd 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java @@ -69,7 +69,7 @@ public void handle(StateRequest.Builder requestBuilder, switch (request.getRequestCase()) { case GET: // Chunk gets into 5 byte return blocks - ByteString byteString = data.get(request.getStateKey()); + ByteString byteString = data.getOrDefault(request.getStateKey(), ByteString.EMPTY); int block = 0; if (request.getGet().getContinuationToken().size() > 0) { block = Integer.parseInt(request.getGet().getContinuationToken().toStringUtf8()); From 82927c5b2e2b7cd68502b31dae57cfc84896ff34 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Wed, 30 Aug 2017 09:44:40 -0700 Subject: [PATCH 2/2] fixup! 
Fix infinite loop --- .../main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 92ba1f3e3f8a..c36164771e9a 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -832,7 +832,7 @@ public PipelineOptions getPipelineOptions() { @Override public T sideInput(PCollectionView view) { - return sideInput(view); + return processBundleContext.sideInput(view); } @Override