From c8478fe1fe107b842d3cfa56b652d740fdf0c18b Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 07:25:08 -0700 Subject: [PATCH 1/6] Mark CombineFnWithContext StateSpecs internal --- .../java/org/apache/beam/sdk/state/StateSpecs.java | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 7b7138489997..5a2a1b6b9c72 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -76,7 +76,9 @@ StateSpec> combining( } /** - * Create a {@link StateSpec} for a {@link CombiningState} which uses a {@link + * For internal use only; no backwards compatibility guarantees + * + *

Create a {@link StateSpec} for a {@link CombiningState} which uses a {@link * CombineFnWithContext} to automatically merge multiple values of type {@code InputT} into a * single resulting {@code OutputT}. * @@ -84,6 +86,7 @@ StateSpec> combining( * * @see #combining(Coder, CombineFnWithContext) */ + @Internal public static StateSpec> combining( CombineFnWithContext combineFn) { @@ -105,11 +108,14 @@ StateSpec> combining( } /** - * Identical to {@link #combining(CombineFnWithContext)}, but with an accumulator coder explicitly - * supplied. + * For internal use only; no backwards compatibility guarantees + * + *

Identical to {@link #combining(CombineFnWithContext)}, but with an accumulator coder + * explicitly supplied. * *

If automatic coder inference fails, use this method. */ + @Internal public static StateSpec> combining( Coder accumCoder, CombineFnWithContext combineFn) { From b0dc523c72a68e870392fbac8ff9f3a87459ab22 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 13:02:15 -0700 Subject: [PATCH 2/6] Allow translation to throw IOException --- .../beam/runners/core/construction/PTransformTranslation.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 9f5f3b50b723..00ea55e6e1ed 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -138,7 +138,8 @@ public static String urnForTransform(PTransform transform) { */ public interface TransformPayloadTranslator> { String getUrn(T transform); - FunctionSpec translate(AppliedPTransform application, SdkComponents components); + FunctionSpec translate(AppliedPTransform application, SdkComponents components) + throws IOException; } /** From 9497e5eaecf5d7eb7f18709935c183b03116f75f Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 07:12:08 -0700 Subject: [PATCH 3/6] Flesh out TimerSpec and StateSpec in Runner API --- .../src/main/proto/beam_runner_api.proto | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto index c8722e6a39fb..16122093c812 100644 --- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto +++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto @@ -247,21 +247,39 @@ message Parameter { } message StateSpec { - // TODO: AST for state spec - string id = 1; - Type type = 2; - - enum Type { - VALUE = 0; - BAG = 1; - MAP = 2; - SET = 3; + oneof spec { + ValueStateSpec value_spec = 1; + BagStateSpec bag_spec = 2; + CombiningStateSpec combining_spec = 3; + MapStateSpec map_spec = 4; + SetStateSpec set_spec = 5; } } +message ValueStateSpec { + string coder_id = 1; +} + +message BagStateSpec { + string element_coder_id = 1; +} + +message CombiningStateSpec { + string accumulator_coder_id = 1; + SdkFunctionSpec combine_fn = 2; +} + +message MapStateSpec { + string key_coder_id = 1; + string value_coder_id = 2; +} + +message SetStateSpec { + string element_coder_id = 1; +} + message TimerSpec { - // TODO: AST for timer spec - string id = 1; + TimeDomain time_domain = 1; } enum IsBounded { From 8fc2eb0aeee9c3bdeaf93897e5e8aa4bb98b98de Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 07:27:52 -0700 Subject: [PATCH 4/6] Add case dispatch to StateSpec This is different than a StateBinder: for a binder, the id is needed and the StateSpec controls the return type. For case dispatch, the dispatcher controls the type and it should just be reading the spec, which does not require the id. Eventually, StateBinder could be removed in favor of StateSpec.Cases>. --- .../org/apache/beam/sdk/state/StateSpec.java | 53 +++++++++++++++++++ .../org/apache/beam/sdk/state/StateSpecs.java | 41 ++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java index b0412bf49c6d..0443f25f7c4e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; /** * A specification of a persistent state cell. This includes information necessary to encode the @@ -40,6 +41,14 @@ public interface StateSpec extends Serializable { @Internal StateT bind(String id, StateBinder binder); + /** + * For internal use only; no backwards-compatibility guarantees. + * + *

Perform case analysis on this {@link StateSpec} using the provided {@link Cases}. + */ + @Internal + ResultT match(Cases cases); + /** * For internal use only; no backwards-compatibility guarantees. * @@ -60,4 +69,48 @@ public interface StateSpec extends Serializable { */ @Internal void finishSpecifying(); + + /** + * Cases for doing a "switch" on the type of {@link StateSpec}. + */ + interface Cases { + ResultT dispatchValue(Coder valueCoder); + ResultT dispatchBag(Coder elementCoder); + ResultT dispatchCombining(Combine.CombineFn combineFn, Coder accumCoder); + ResultT dispatchMap(Coder keyCoder, Coder valueCoder); + ResultT dispatchSet(Coder elementCoder); + + /** + * A base class for a visitor with a default method for cases it is not interested in. + */ + abstract class WithDefault implements Cases { + + protected abstract ResultT dispatchDefault(); + + @Override + public ResultT dispatchValue(Coder valueCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchBag(Coder elementCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchCombining(Combine.CombineFn combineFn, Coder accumCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchMap(Coder keyCoder, Coder valueCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchSet(Coder elementCoder) { + return dispatchDefault(); + } + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 5a2a1b6b9c72..42223047cc58 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -278,6 +278,11 @@ public ValueState bind(String id, StateBinder visitor) { return visitor.bindValue(id, this, coder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchValue(coder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -342,6 +347,11 @@ public CombiningState bind( return visitor.bindCombining(id, this, accumCoder, combineFn); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchCombining(combineFn, accumCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -413,6 +423,14 @@ public CombiningState bind( return visitor.bindCombiningWithContext(id, this, accumCoder, combineFn); } + @Override + public ResultT match(Cases cases) { + throw new UnsupportedOperationException( + String.format( + "%s is for internal use only and does not support case dispatch", + getClass().getSimpleName())); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -480,6 +498,11 @@ public BagState bind(String id, StateBinder visitor) { return visitor.bindBag(id, this, elemCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchBag(elemCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -536,6 +559,11 @@ public MapState bind(String id, StateBinder visitor) { return visitor.bindMap(id, this, keyCoder, valueCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchMap(keyCoder, valueCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -600,6 +628,11 @@ public SetState bind(String id, StateBinder visitor) { return visitor.bindSet(id, this, elemCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchSet(elemCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -663,6 +696,14 @@ public WatermarkHoldState bind(String id, StateBinder visitor) { return visitor.bindWatermark(id, this, timestampCombiner); } + @Override + public ResultT match(Cases cases) { + throw new UnsupportedOperationException( + String.format( + "%s is for internal use only and does not support case dispatch", + getClass().getSimpleName())); + } + @Override public void offerCoders(Coder[] coders) { } From a250ce58c6a0caf473842c4e5e6f980a828dde55 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 22:51:18 -0700 Subject: [PATCH 5/6] Make Java serialized CombineFn URN public --- .../beam/runners/core/construction/CombineTranslation.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java index 855fba740a02..28bc9a15e47f 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -49,7 +49,7 @@ * RunnerApi.CombinePayload} protos. */ public class CombineTranslation { - private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1"; + public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1"; public static CombinePayload toProto( AppliedPTransform> combine, SdkComponents sdkComponents) From 39220dbca944a2496587c543de2a4eb01004bd76 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 07:12:29 -0700 Subject: [PATCH 6/6] Implement TimerSpec and StateSpec translation --- .../core/construction/CombineTranslation.java | 2 +- .../core/construction/ParDoTranslation.java | 215 +++++++++++++++-- .../construction/ParDoTranslationTest.java | 218 +++++++++++------- 3 files changed, 343 insertions(+), 92 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java index 28bc9a15e47f..472b6f8b1d48 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -86,7 +86,7 @@ private static Coder extractAccumulatorCoder( .getAccumulatorCoder(); } - private static SdkFunctionSpec toProto(GlobalCombineFn combineFn) { + public static SdkFunctionSpec toProto(GlobalCombineFn combineFn) { return SdkFunctionSpec.newBuilder() // TODO: Set Java SDK Environment URN .setSpec( diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 1c81f8ce05d5..fe66179bc5e7 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -19,10 +19,12 @@ package org.apache.beam.runners.core.construction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; @@ -46,9 +48,12 @@ import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder; -import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec; -import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.PTransform; @@ -107,7 +112,8 @@ public String getUrn(ParDo.MultiOutput transform) { @Override public FunctionSpec translate( - AppliedPTransform> transform, SdkComponents components) { + AppliedPTransform> transform, SdkComponents components) + throws IOException { ParDoPayload payload = toProto(transform.getTransform(), components); return RunnerApi.FunctionSpec.newBuilder() .setUrn(PAR_DO_TRANSFORM_URN) @@ -128,8 +134,10 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar { } } - public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents components) { - DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass()); + public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents components) + throws IOException { + DoFn doFn = parDo.getFn(); + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); Map states = signature.stateDeclarations(); Map timers = signature.timerDeclarations(); List parameters = signature.processElement().extraParameters(); @@ -146,16 +154,62 @@ public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents } } for (Map.Entry state : states.entrySet()) { - StateSpec spec = toProto(state.getValue()); + RunnerApi.StateSpec spec = + toProto(getStateSpecOrCrash(state.getValue(), doFn), components); builder.putStateSpecs(state.getKey(), spec); } for (Map.Entry timer : timers.entrySet()) { - TimerSpec spec = toProto(timer.getValue()); + RunnerApi.TimerSpec spec = + toProto(getTimerSpecOrCrash(timer.getValue(), doFn)); builder.putTimerSpecs(timer.getKey(), spec); } return builder.build(); } + private static StateSpec getStateSpecOrCrash( + StateDeclaration stateDeclaration, DoFn target) { + try { + Object fieldValue = stateDeclaration.field().get(target); + checkState(fieldValue instanceof StateSpec, + "Malformed %s class %s: state declaration field %s does not have type %s.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + stateDeclaration.field().getName(), + StateSpec.class); + + return (StateSpec) stateDeclaration.field().get(target); + } catch (IllegalAccessException exc) { + throw new RuntimeException( + String.format( + "Malformed %s class %s: state declaration field %s is not accessible.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + stateDeclaration.field().getName())); + } + } + + private static TimerSpec getTimerSpecOrCrash( + TimerDeclaration timerDeclaration, DoFn target) { + try { + Object fieldValue = timerDeclaration.field().get(target); + checkState(fieldValue instanceof TimerSpec, + "Malformed %s class %s: timer declaration field %s does not have type %s.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + timerDeclaration.field().getName(), + TimerSpec.class); + + return (TimerSpec) timerDeclaration.field().get(target); + } catch (IllegalAccessException exc) { + throw new RuntimeException( + String.format( + "Malformed %s class %s: timer declaration field %s is not accessible.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + timerDeclaration.field().getName())); + } + } + public static DoFn getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException { return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn(); } @@ -179,14 +233,149 @@ public static RunnerApi.PCollection getMainInput( return components.getPcollectionsOrThrow(ptransform.getInputsOrThrow(mainInputId)); } - // TODO: Implement - private static StateSpec toProto(StateDeclaration state) { - throw new UnsupportedOperationException("Not yet supported"); + @VisibleForTesting + static RunnerApi.StateSpec toProto(StateSpec stateSpec, final SdkComponents components) + throws IOException { + final RunnerApi.StateSpec.Builder builder = RunnerApi.StateSpec.newBuilder(); + + return stateSpec.match( + new StateSpec.Cases() { + @Override + public RunnerApi.StateSpec dispatchValue(Coder valueCoder) { + return builder + .setValueSpec( + RunnerApi.ValueStateSpec.newBuilder() + .setCoderId(registerCoderOrThrow(components, valueCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchBag(Coder elementCoder) { + return builder + .setBagSpec( + RunnerApi.BagStateSpec.newBuilder() + .setElementCoderId(registerCoderOrThrow(components, elementCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchCombining( + Combine.CombineFn combineFn, Coder accumCoder) { + return builder + .setCombiningSpec( + RunnerApi.CombiningStateSpec.newBuilder() + .setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder)) + .setCombineFn(CombineTranslation.toProto(combineFn))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchMap(Coder keyCoder, Coder valueCoder) { + return builder + .setMapSpec( + RunnerApi.MapStateSpec.newBuilder() + .setKeyCoderId(registerCoderOrThrow(components, keyCoder)) + .setValueCoderId(registerCoderOrThrow(components, valueCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchSet(Coder elementCoder) { + return builder + .setSetSpec( + RunnerApi.SetStateSpec.newBuilder() + .setElementCoderId(registerCoderOrThrow(components, elementCoder))) + .build(); + } + }); + } + + @VisibleForTesting + static StateSpec fromProto(RunnerApi.StateSpec stateSpec, RunnerApi.Components components) + throws IOException { + switch (stateSpec.getSpecCase()) { + case VALUE_SPEC: + return StateSpecs.value( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getValueSpec().getCoderId()), components)); + case BAG_SPEC: + return StateSpecs.bag( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getBagSpec().getElementCoderId()), + components)); + case COMBINING_SPEC: + FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec(); + + if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) { + throw new UnsupportedOperationException( + String.format( + "Cannot create %s from non-Java %s: %s", + StateSpec.class.getSimpleName(), + Combine.CombineFn.class.getSimpleName(), + combineFnSpec.getUrn())); + } + + Combine.CombineFn combineFn = + (Combine.CombineFn) + SerializableUtils.deserializeFromByteArray( + combineFnSpec.getParameter().unpack(BytesValue.class).toByteArray(), + Combine.CombineFn.class.getSimpleName()); + + // Rawtype coder cast because it is required to be a valid accumulator coder + // for the CombineFn, by construction + return StateSpecs.combining( + (Coder) + CoderTranslation.fromProto( + components + .getCodersMap() + .get(stateSpec.getCombiningSpec().getAccumulatorCoderId()), + components), + combineFn); + + case MAP_SPEC: + return StateSpecs.map( + CoderTranslation.fromProto( + components.getCodersOrThrow(stateSpec.getMapSpec().getKeyCoderId()), components), + CoderTranslation.fromProto( + components.getCodersOrThrow(stateSpec.getMapSpec().getValueCoderId()), components)); + + case SET_SPEC: + return StateSpecs.set( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getSetSpec().getElementCoderId()), + components)); + + case SPEC_NOT_SET: + default: + throw new IllegalArgumentException( + String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec)); + + } + } + + private static String registerCoderOrThrow(SdkComponents components, Coder coder) { + try { + return components.registerCoder(coder); + } catch (IOException exc) { + throw new RuntimeException("Failure to register coder", exc); + } } - // TODO: Implement - private static TimerSpec toProto(TimerDeclaration timer) { - throw new UnsupportedOperationException("Not yet supported"); + private static RunnerApi.TimerSpec toProto(TimerSpec timer) { + return RunnerApi.TimerSpec.newBuilder().setTimeDomain(toProto(timer.getTimeDomain())).build(); + } + + private static RunnerApi.TimeDomain toProto(TimeDomain timeDomain) { + switch(timeDomain) { + case EVENT_TIME: + return RunnerApi.TimeDomain.EVENT_TIME; + case PROCESSING_TIME: + return RunnerApi.TimeDomain.PROCESSING_TIME; + case SYNCHRONIZED_PROCESSING_TIME: + return RunnerApi.TimeDomain.SYNCHRONIZED_PROCESSING_TIME; + default: + throw new IllegalArgumentException("Unknown time domain"); + } } @AutoValue diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index ec2795746451..46f6a806f292 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -28,6 +28,7 @@ import java.util.Map; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; @@ -62,98 +63,159 @@ import org.hamcrest.Matchers; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; +import org.junit.runners.Suite; /** Tests for {@link ParDoTranslation}. */ -@RunWith(Parameterized.class) +@RunWith(Suite.class) +@Suite.SuiteClasses({ + ParDoTranslationTest.TestParDoPayloadTranslation.class, + ParDoTranslationTest.TestStateAndTimerTranslation.class +}) public class ParDoTranslationTest { - public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); - - private static PCollectionView singletonSideInput = - p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)) - .apply(View.asSingleton()); - private static PCollectionView>> multimapSideInput = - p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam"))) - .setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())) - .apply(View.asMultimap()); - - private static PCollection> mainInput = - p.apply("CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))); - - @Parameters(name = "{index}: {0}") - public static Iterable> data() { - return ImmutableList.>of( - ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag(), TupleTagList.empty()), - ParDo.of(new DropElementsFn()) - .withOutputTags(new TupleTag(), TupleTagList.empty()) - .withSideInputs(singletonSideInput, multimapSideInput), - ParDo.of(new DropElementsFn()) - .withOutputTags( - new TupleTag(), - TupleTagList.of(new TupleTag() {}).and(new TupleTag() {})) - .withSideInputs(singletonSideInput, multimapSideInput), - ParDo.of(new DropElementsFn()) - .withOutputTags( - new TupleTag(), - TupleTagList.of(new TupleTag() {}).and(new TupleTag() {}))); - } - @Parameter(0) - public ParDo.MultiOutput, Void> parDo; + /** + * Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. + */ + @RunWith(Parameterized.class) + public static class TestParDoPayloadTranslation { + public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + + private static PCollectionView singletonSideInput = + p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)) + .apply(View.asSingleton()); + private static PCollectionView>> multimapSideInput = + p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam"))) + .setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())) + .apply(View.asMultimap()); - @Test - public void testToAndFromProto() throws Exception { - SdkComponents components = SdkComponents.create(); - ParDoPayload payload = ParDoTranslation.toProto(parDo, components); + private static PCollection> mainInput = + p.apply( + "CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))); - assertThat(ParDoTranslation.getDoFn(payload), Matchers.>equalTo(parDo.getFn())); - assertThat( - ParDoTranslation.getMainOutputTag(payload), - Matchers.>equalTo(parDo.getMainOutputTag())); - for (PCollectionView view : parDo.getSideInputs()) { - payload.getSideInputsOrThrow(view.getTagInternal().getId()); + @Parameters(name = "{index}: {0}") + public static Iterable> data() { + return ImmutableList.>of( + ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag(), TupleTagList.empty()), + ParDo.of(new DropElementsFn()) + .withOutputTags(new TupleTag(), TupleTagList.empty()) + .withSideInputs(singletonSideInput, multimapSideInput), + ParDo.of(new DropElementsFn()) + .withOutputTags( + new TupleTag(), + TupleTagList.of(new TupleTag() {}).and(new TupleTag() {})) + .withSideInputs(singletonSideInput, multimapSideInput), + ParDo.of(new DropElementsFn()) + .withOutputTags( + new TupleTag(), + TupleTagList.of(new TupleTag() {}).and(new TupleTag() {}))); } - } - @Test - public void toAndFromTransformProto() throws Exception { - Map, PValue> inputs = new HashMap<>(); - inputs.put(new TupleTag>() {}, mainInput); - inputs.putAll(parDo.getAdditionalInputs()); - PCollectionTuple output = mainInput.apply(parDo); - - SdkComponents components = SdkComponents.create(); - String transformId = - components.registerPTransform( - AppliedPTransform.>, PCollection, MultiOutput>of( - "foo", inputs, output.expand(), parDo, p), - Collections.>emptyList()); - - Components protoComponents = components.toComponents(); - RunnerApi.PTransform protoTransform = - protoComponents.getTransformsOrThrow(transformId); - ParDoPayload parDoPayload = protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); - for (PCollectionView view : parDo.getSideInputs()) { - SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); - PCollectionView restoredView = - ParDoTranslation.fromProto( - sideInput, view.getTagInternal().getId(), protoTransform, protoComponents); - assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); - assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); + @Parameter(0) + public ParDo.MultiOutput, Void> parDo; + + @Test + public void testToAndFromProto() throws Exception { + SdkComponents components = SdkComponents.create(); + ParDoPayload payload = ParDoTranslation.toProto(parDo, components); + + assertThat(ParDoTranslation.getDoFn(payload), Matchers.>equalTo(parDo.getFn())); assertThat( - restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass())); + ParDoTranslation.getMainOutputTag(payload), + Matchers.>equalTo(parDo.getMainOutputTag())); + for (PCollectionView view : parDo.getSideInputs()) { + payload.getSideInputsOrThrow(view.getTagInternal().getId()); + } + } + + @Test + public void toAndFromTransformProto() throws Exception { + Map, PValue> inputs = new HashMap<>(); + inputs.put(new TupleTag>() {}, mainInput); + inputs.putAll(parDo.getAdditionalInputs()); + PCollectionTuple output = mainInput.apply(parDo); + + SdkComponents components = SdkComponents.create(); + String transformId = + components.registerPTransform( + AppliedPTransform.>, PCollection, MultiOutput>of( + "foo", inputs, output.expand(), parDo, p), + Collections.>emptyList()); + + Components protoComponents = components.toComponents(); + RunnerApi.PTransform protoTransform = protoComponents.getTransformsOrThrow(transformId); + ParDoPayload parDoPayload = + protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); + for (PCollectionView view : parDo.getSideInputs()) { + SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); + PCollectionView restoredView = + ParDoTranslation.fromProto( + sideInput, view.getTagInternal().getId(), protoTransform, protoComponents); + assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); + assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); + assertThat( + restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass())); + assertThat( + restoredView.getWindowingStrategyInternal(), + Matchers.>equalTo( + view.getWindowingStrategyInternal().fixDefaults())); + assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); + } + String mainInputId = components.registerPCollection(mainInput); assertThat( - restoredView.getWindowingStrategyInternal(), - Matchers.>equalTo( - view.getWindowingStrategyInternal().fixDefaults())); - assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); + ParDoTranslation.getMainInput(protoTransform, protoComponents), + equalTo(protoComponents.getPcollectionsOrThrow(mainInputId))); + } + } + + /** + * Tests for translating state and timer bits to/from protos. + */ + @RunWith(JUnit4.class) + public static class TestStateAndTimerTranslation { + + @Test + public void testValueStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.value(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testBagStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.bag(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testSetStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.set(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testMapStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); } - String mainInputId = components.registerPCollection(mainInput); - assertThat( - ParDoTranslation.getMainInput(protoTransform, protoComponents), - equalTo(protoComponents.getPcollectionsOrThrow(mainInputId))); } private static class DropElementsFn extends DoFn, Void> {