diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java index 855fba740a02..472b6f8b1d48 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CombineTranslation.java @@ -49,7 +49,7 @@ * RunnerApi.CombinePayload} protos. */ public class CombineTranslation { - private static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1"; + public static final String JAVA_SERIALIZED_COMBINE_FN_URN = "urn:beam:java:combinefn:v1"; public static CombinePayload toProto( AppliedPTransform> combine, SdkComponents sdkComponents) @@ -86,7 +86,7 @@ private static Coder extractAccumulatorCoder( .getAccumulatorCoder(); } - private static SdkFunctionSpec toProto(GlobalCombineFn combineFn) { + public static SdkFunctionSpec toProto(GlobalCombineFn combineFn) { return SdkFunctionSpec.newBuilder() // TODO: Set Java SDK Environment URN .setSpec( diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 9f5f3b50b723..00ea55e6e1ed 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -138,7 +138,8 @@ public static String urnForTransform(PTransform transform) { */ public interface TransformPayloadTranslator> { String getUrn(T transform); - FunctionSpec translate(AppliedPTransform application, SdkComponents components); + FunctionSpec translate(AppliedPTransform application, SdkComponents components) + throws IOException; } /** diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 1c81f8ce05d5..fe66179bc5e7 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -19,10 +19,12 @@ package org.apache.beam.runners.core.construction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; @@ -46,9 +48,12 @@ import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput.Builder; -import org.apache.beam.sdk.common.runner.v1.RunnerApi.StateSpec; -import org.apache.beam.sdk.common.runner.v1.RunnerApi.TimerSpec; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.PTransform; @@ -107,7 +112,8 @@ public String getUrn(ParDo.MultiOutput transform) { @Override public FunctionSpec translate( - AppliedPTransform> transform, SdkComponents components) { + AppliedPTransform> transform, SdkComponents components) + throws IOException { ParDoPayload payload = toProto(transform.getTransform(), components); return RunnerApi.FunctionSpec.newBuilder() .setUrn(PAR_DO_TRANSFORM_URN) @@ -128,8 +134,10 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar { } } - public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents components) { - DoFnSignature signature = DoFnSignatures.getSignature(parDo.getFn().getClass()); + public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents components) + throws IOException { + DoFn doFn = parDo.getFn(); + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); Map states = signature.stateDeclarations(); Map timers = signature.timerDeclarations(); List parameters = signature.processElement().extraParameters(); @@ -146,16 +154,62 @@ public static ParDoPayload toProto(ParDo.MultiOutput parDo, SdkComponents } } for (Map.Entry state : states.entrySet()) { - StateSpec spec = toProto(state.getValue()); + RunnerApi.StateSpec spec = + toProto(getStateSpecOrCrash(state.getValue(), doFn), components); builder.putStateSpecs(state.getKey(), spec); } for (Map.Entry timer : timers.entrySet()) { - TimerSpec spec = toProto(timer.getValue()); + RunnerApi.TimerSpec spec = + toProto(getTimerSpecOrCrash(timer.getValue(), doFn)); builder.putTimerSpecs(timer.getKey(), spec); } return builder.build(); } + private static StateSpec getStateSpecOrCrash( + StateDeclaration stateDeclaration, DoFn target) { + try { + Object fieldValue = stateDeclaration.field().get(target); + checkState(fieldValue instanceof StateSpec, + "Malformed %s class %s: state declaration field %s does not have type %s.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + stateDeclaration.field().getName(), + StateSpec.class); + + return (StateSpec) stateDeclaration.field().get(target); + } catch (IllegalAccessException exc) { + throw new RuntimeException( + String.format( + "Malformed %s class %s: state declaration field %s is not accessible.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + stateDeclaration.field().getName())); + } + } + + private static TimerSpec getTimerSpecOrCrash( + TimerDeclaration timerDeclaration, DoFn target) { + try { + Object fieldValue = timerDeclaration.field().get(target); + checkState(fieldValue instanceof TimerSpec, + "Malformed %s class %s: timer declaration field %s does not have type %s.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + timerDeclaration.field().getName(), + TimerSpec.class); + + return (TimerSpec) timerDeclaration.field().get(target); + } catch (IllegalAccessException exc) { + throw new RuntimeException( + String.format( + "Malformed %s class %s: timer declaration field %s is not accessible.", + DoFn.class.getSimpleName(), + target.getClass().getName(), + timerDeclaration.field().getName())); + } + } + public static DoFn getDoFn(ParDoPayload payload) throws InvalidProtocolBufferException { return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn(); } @@ -179,14 +233,149 @@ public static RunnerApi.PCollection getMainInput( return components.getPcollectionsOrThrow(ptransform.getInputsOrThrow(mainInputId)); } - // TODO: Implement - private static StateSpec toProto(StateDeclaration state) { - throw new UnsupportedOperationException("Not yet supported"); + @VisibleForTesting + static RunnerApi.StateSpec toProto(StateSpec stateSpec, final SdkComponents components) + throws IOException { + final RunnerApi.StateSpec.Builder builder = RunnerApi.StateSpec.newBuilder(); + + return stateSpec.match( + new StateSpec.Cases() { + @Override + public RunnerApi.StateSpec dispatchValue(Coder valueCoder) { + return builder + .setValueSpec( + RunnerApi.ValueStateSpec.newBuilder() + .setCoderId(registerCoderOrThrow(components, valueCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchBag(Coder elementCoder) { + return builder + .setBagSpec( + RunnerApi.BagStateSpec.newBuilder() + .setElementCoderId(registerCoderOrThrow(components, elementCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchCombining( + Combine.CombineFn combineFn, Coder accumCoder) { + return builder + .setCombiningSpec( + RunnerApi.CombiningStateSpec.newBuilder() + .setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder)) + .setCombineFn(CombineTranslation.toProto(combineFn))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchMap(Coder keyCoder, Coder valueCoder) { + return builder + .setMapSpec( + RunnerApi.MapStateSpec.newBuilder() + .setKeyCoderId(registerCoderOrThrow(components, keyCoder)) + .setValueCoderId(registerCoderOrThrow(components, valueCoder))) + .build(); + } + + @Override + public RunnerApi.StateSpec dispatchSet(Coder elementCoder) { + return builder + .setSetSpec( + RunnerApi.SetStateSpec.newBuilder() + .setElementCoderId(registerCoderOrThrow(components, elementCoder))) + .build(); + } + }); + } + + @VisibleForTesting + static StateSpec fromProto(RunnerApi.StateSpec stateSpec, RunnerApi.Components components) + throws IOException { + switch (stateSpec.getSpecCase()) { + case VALUE_SPEC: + return StateSpecs.value( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getValueSpec().getCoderId()), components)); + case BAG_SPEC: + return StateSpecs.bag( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getBagSpec().getElementCoderId()), + components)); + case COMBINING_SPEC: + FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec(); + + if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) { + throw new UnsupportedOperationException( + String.format( + "Cannot create %s from non-Java %s: %s", + StateSpec.class.getSimpleName(), + Combine.CombineFn.class.getSimpleName(), + combineFnSpec.getUrn())); + } + + Combine.CombineFn combineFn = + (Combine.CombineFn) + SerializableUtils.deserializeFromByteArray( + combineFnSpec.getParameter().unpack(BytesValue.class).toByteArray(), + Combine.CombineFn.class.getSimpleName()); + + // Rawtype coder cast because it is required to be a valid accumulator coder + // for the CombineFn, by construction + return StateSpecs.combining( + (Coder) + CoderTranslation.fromProto( + components + .getCodersMap() + .get(stateSpec.getCombiningSpec().getAccumulatorCoderId()), + components), + combineFn); + + case MAP_SPEC: + return StateSpecs.map( + CoderTranslation.fromProto( + components.getCodersOrThrow(stateSpec.getMapSpec().getKeyCoderId()), components), + CoderTranslation.fromProto( + components.getCodersOrThrow(stateSpec.getMapSpec().getValueCoderId()), components)); + + case SET_SPEC: + return StateSpecs.set( + CoderTranslation.fromProto( + components.getCodersMap().get(stateSpec.getSetSpec().getElementCoderId()), + components)); + + case SPEC_NOT_SET: + default: + throw new IllegalArgumentException( + String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec)); + + } + } + + private static String registerCoderOrThrow(SdkComponents components, Coder coder) { + try { + return components.registerCoder(coder); + } catch (IOException exc) { + throw new RuntimeException("Failure to register coder", exc); + } } - // TODO: Implement - private static TimerSpec toProto(TimerDeclaration timer) { - throw new UnsupportedOperationException("Not yet supported"); + private static RunnerApi.TimerSpec toProto(TimerSpec timer) { + return RunnerApi.TimerSpec.newBuilder().setTimeDomain(toProto(timer.getTimeDomain())).build(); + } + + private static RunnerApi.TimeDomain toProto(TimeDomain timeDomain) { + switch(timeDomain) { + case EVENT_TIME: + return RunnerApi.TimeDomain.EVENT_TIME; + case PROCESSING_TIME: + return RunnerApi.TimeDomain.PROCESSING_TIME; + case SYNCHRONIZED_PROCESSING_TIME: + return RunnerApi.TimeDomain.SYNCHRONIZED_PROCESSING_TIME; + default: + throw new IllegalArgumentException("Unknown time domain"); + } } @AutoValue diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index ec2795746451..46f6a806f292 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -28,6 +28,7 @@ import java.util.Map; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; @@ -62,98 +63,159 @@ import org.hamcrest.Matchers; import org.junit.Test; import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; +import org.junit.runners.Suite; /** Tests for {@link ParDoTranslation}. */ -@RunWith(Parameterized.class) +@RunWith(Suite.class) +@Suite.SuiteClasses({ + ParDoTranslationTest.TestParDoPayloadTranslation.class, + ParDoTranslationTest.TestStateAndTimerTranslation.class +}) public class ParDoTranslationTest { - public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); - - private static PCollectionView singletonSideInput = - p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)) - .apply(View.asSingleton()); - private static PCollectionView>> multimapSideInput = - p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam"))) - .setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())) - .apply(View.asMultimap()); - - private static PCollection> mainInput = - p.apply("CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))); - - @Parameters(name = "{index}: {0}") - public static Iterable> data() { - return ImmutableList.>of( - ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag(), TupleTagList.empty()), - ParDo.of(new DropElementsFn()) - .withOutputTags(new TupleTag(), TupleTagList.empty()) - .withSideInputs(singletonSideInput, multimapSideInput), - ParDo.of(new DropElementsFn()) - .withOutputTags( - new TupleTag(), - TupleTagList.of(new TupleTag() {}).and(new TupleTag() {})) - .withSideInputs(singletonSideInput, multimapSideInput), - ParDo.of(new DropElementsFn()) - .withOutputTags( - new TupleTag(), - TupleTagList.of(new TupleTag() {}).and(new TupleTag() {}))); - } - @Parameter(0) - public ParDo.MultiOutput, Void> parDo; + /** + * Tests for translating various {@link ParDo} transforms to/from {@link ParDoPayload} protos. + */ + @RunWith(Parameterized.class) + public static class TestParDoPayloadTranslation { + public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + + private static PCollectionView singletonSideInput = + p.apply("GenerateSingleton", GenerateSequence.from(0L).to(1L)) + .apply(View.asSingleton()); + private static PCollectionView>> multimapSideInput = + p.apply("CreateMultimap", Create.of(KV.of(1L, "foo"), KV.of(1L, "bar"), KV.of(2L, "spam"))) + .setCoder(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of())) + .apply(View.asMultimap()); - @Test - public void testToAndFromProto() throws Exception { - SdkComponents components = SdkComponents.create(); - ParDoPayload payload = ParDoTranslation.toProto(parDo, components); + private static PCollection> mainInput = + p.apply( + "CreateMainInput", Create.empty(KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of()))); - assertThat(ParDoTranslation.getDoFn(payload), Matchers.>equalTo(parDo.getFn())); - assertThat( - ParDoTranslation.getMainOutputTag(payload), - Matchers.>equalTo(parDo.getMainOutputTag())); - for (PCollectionView view : parDo.getSideInputs()) { - payload.getSideInputsOrThrow(view.getTagInternal().getId()); + @Parameters(name = "{index}: {0}") + public static Iterable> data() { + return ImmutableList.>of( + ParDo.of(new DropElementsFn()).withOutputTags(new TupleTag(), TupleTagList.empty()), + ParDo.of(new DropElementsFn()) + .withOutputTags(new TupleTag(), TupleTagList.empty()) + .withSideInputs(singletonSideInput, multimapSideInput), + ParDo.of(new DropElementsFn()) + .withOutputTags( + new TupleTag(), + TupleTagList.of(new TupleTag() {}).and(new TupleTag() {})) + .withSideInputs(singletonSideInput, multimapSideInput), + ParDo.of(new DropElementsFn()) + .withOutputTags( + new TupleTag(), + TupleTagList.of(new TupleTag() {}).and(new TupleTag() {}))); } - } - @Test - public void toAndFromTransformProto() throws Exception { - Map, PValue> inputs = new HashMap<>(); - inputs.put(new TupleTag>() {}, mainInput); - inputs.putAll(parDo.getAdditionalInputs()); - PCollectionTuple output = mainInput.apply(parDo); - - SdkComponents components = SdkComponents.create(); - String transformId = - components.registerPTransform( - AppliedPTransform.>, PCollection, MultiOutput>of( - "foo", inputs, output.expand(), parDo, p), - Collections.>emptyList()); - - Components protoComponents = components.toComponents(); - RunnerApi.PTransform protoTransform = - protoComponents.getTransformsOrThrow(transformId); - ParDoPayload parDoPayload = protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); - for (PCollectionView view : parDo.getSideInputs()) { - SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); - PCollectionView restoredView = - ParDoTranslation.fromProto( - sideInput, view.getTagInternal().getId(), protoTransform, protoComponents); - assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); - assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); + @Parameter(0) + public ParDo.MultiOutput, Void> parDo; + + @Test + public void testToAndFromProto() throws Exception { + SdkComponents components = SdkComponents.create(); + ParDoPayload payload = ParDoTranslation.toProto(parDo, components); + + assertThat(ParDoTranslation.getDoFn(payload), Matchers.>equalTo(parDo.getFn())); assertThat( - restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass())); + ParDoTranslation.getMainOutputTag(payload), + Matchers.>equalTo(parDo.getMainOutputTag())); + for (PCollectionView view : parDo.getSideInputs()) { + payload.getSideInputsOrThrow(view.getTagInternal().getId()); + } + } + + @Test + public void toAndFromTransformProto() throws Exception { + Map, PValue> inputs = new HashMap<>(); + inputs.put(new TupleTag>() {}, mainInput); + inputs.putAll(parDo.getAdditionalInputs()); + PCollectionTuple output = mainInput.apply(parDo); + + SdkComponents components = SdkComponents.create(); + String transformId = + components.registerPTransform( + AppliedPTransform.>, PCollection, MultiOutput>of( + "foo", inputs, output.expand(), parDo, p), + Collections.>emptyList()); + + Components protoComponents = components.toComponents(); + RunnerApi.PTransform protoTransform = protoComponents.getTransformsOrThrow(transformId); + ParDoPayload parDoPayload = + protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); + for (PCollectionView view : parDo.getSideInputs()) { + SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); + PCollectionView restoredView = + ParDoTranslation.fromProto( + sideInput, view.getTagInternal().getId(), protoTransform, protoComponents); + assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); + assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); + assertThat( + restoredView.getWindowMappingFn(), instanceOf(view.getWindowMappingFn().getClass())); + assertThat( + restoredView.getWindowingStrategyInternal(), + Matchers.>equalTo( + view.getWindowingStrategyInternal().fixDefaults())); + assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); + } + String mainInputId = components.registerPCollection(mainInput); assertThat( - restoredView.getWindowingStrategyInternal(), - Matchers.>equalTo( - view.getWindowingStrategyInternal().fixDefaults())); - assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); + ParDoTranslation.getMainInput(protoTransform, protoComponents), + equalTo(protoComponents.getPcollectionsOrThrow(mainInputId))); + } + } + + /** + * Tests for translating state and timer bits to/from protos. + */ + @RunWith(JUnit4.class) + public static class TestStateAndTimerTranslation { + + @Test + public void testValueStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.value(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testBagStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.bag(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testSetStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.set(VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); + } + + @Test + public void testMapStateSpecToFromProto() throws Exception { + SdkComponents sdkComponents = SdkComponents.create(); + StateSpec stateSpec = StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()); + StateSpec deserializedStateSpec = + ParDoTranslation.fromProto( + ParDoTranslation.toProto(stateSpec, sdkComponents), sdkComponents.toComponents()); + assertThat(stateSpec, Matchers.>equalTo(deserializedStateSpec)); } - String mainInputId = components.registerPCollection(mainInput); - assertThat( - ParDoTranslation.getMainInput(protoTransform, protoComponents), - equalTo(protoComponents.getPcollectionsOrThrow(mainInputId))); } private static class DropElementsFn extends DoFn, Void> { diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto index c8722e6a39fb..16122093c812 100644 --- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto +++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto @@ -247,21 +247,39 @@ message Parameter { } message StateSpec { - // TODO: AST for state spec - string id = 1; - Type type = 2; - - enum Type { - VALUE = 0; - BAG = 1; - MAP = 2; - SET = 3; + oneof spec { + ValueStateSpec value_spec = 1; + BagStateSpec bag_spec = 2; + CombiningStateSpec combining_spec = 3; + MapStateSpec map_spec = 4; + SetStateSpec set_spec = 5; } } +message ValueStateSpec { + string coder_id = 1; +} + +message BagStateSpec { + string element_coder_id = 1; +} + +message CombiningStateSpec { + string accumulator_coder_id = 1; + SdkFunctionSpec combine_fn = 2; +} + +message MapStateSpec { + string key_coder_id = 1; + string value_coder_id = 2; +} + +message SetStateSpec { + string element_coder_id = 1; +} + message TimerSpec { - // TODO: AST for timer spec - string id = 1; + TimeDomain time_domain = 1; } enum IsBounded { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java index b0412bf49c6d..0443f25f7c4e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; /** * A specification of a persistent state cell. This includes information necessary to encode the @@ -40,6 +41,14 @@ public interface StateSpec extends Serializable { @Internal StateT bind(String id, StateBinder binder); + /** + * For internal use only; no backwards-compatibility guarantees. + * + *

Perform case analysis on this {@link StateSpec} using the provided {@link Cases}. + */ + @Internal + ResultT match(Cases cases); + /** * For internal use only; no backwards-compatibility guarantees. * @@ -60,4 +69,48 @@ public interface StateSpec extends Serializable { */ @Internal void finishSpecifying(); + + /** + * Cases for doing a "switch" on the type of {@link StateSpec}. + */ + interface Cases { + ResultT dispatchValue(Coder valueCoder); + ResultT dispatchBag(Coder elementCoder); + ResultT dispatchCombining(Combine.CombineFn combineFn, Coder accumCoder); + ResultT dispatchMap(Coder keyCoder, Coder valueCoder); + ResultT dispatchSet(Coder elementCoder); + + /** + * A base class for a visitor with a default method for cases it is not interested in. + */ + abstract class WithDefault implements Cases { + + protected abstract ResultT dispatchDefault(); + + @Override + public ResultT dispatchValue(Coder valueCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchBag(Coder elementCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchCombining(Combine.CombineFn combineFn, Coder accumCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchMap(Coder keyCoder, Coder valueCoder) { + return dispatchDefault(); + } + + @Override + public ResultT dispatchSet(Coder elementCoder) { + return dispatchDefault(); + } + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 7b7138489997..42223047cc58 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -76,7 +76,9 @@ StateSpec> combining( } /** - * Create a {@link StateSpec} for a {@link CombiningState} which uses a {@link + * For internal use only; no backwards compatibility guarantees + * + *

Create a {@link StateSpec} for a {@link CombiningState} which uses a {@link * CombineFnWithContext} to automatically merge multiple values of type {@code InputT} into a * single resulting {@code OutputT}. * @@ -84,6 +86,7 @@ StateSpec> combining( * * @see #combining(Coder, CombineFnWithContext) */ + @Internal public static StateSpec> combining( CombineFnWithContext combineFn) { @@ -105,11 +108,14 @@ StateSpec> combining( } /** - * Identical to {@link #combining(CombineFnWithContext)}, but with an accumulator coder explicitly - * supplied. + * For internal use only; no backwards compatibility guarantees + * + *

Identical to {@link #combining(CombineFnWithContext)}, but with an accumulator coder + * explicitly supplied. * *

If automatic coder inference fails, use this method. */ + @Internal public static StateSpec> combining( Coder accumCoder, CombineFnWithContext combineFn) { @@ -272,6 +278,11 @@ public ValueState bind(String id, StateBinder visitor) { return visitor.bindValue(id, this, coder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchValue(coder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -336,6 +347,11 @@ public CombiningState bind( return visitor.bindCombining(id, this, accumCoder, combineFn); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchCombining(combineFn, accumCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -407,6 +423,14 @@ public CombiningState bind( return visitor.bindCombiningWithContext(id, this, accumCoder, combineFn); } + @Override + public ResultT match(Cases cases) { + throw new UnsupportedOperationException( + String.format( + "%s is for internal use only and does not support case dispatch", + getClass().getSimpleName())); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -474,6 +498,11 @@ public BagState bind(String id, StateBinder visitor) { return visitor.bindBag(id, this, elemCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchBag(elemCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -530,6 +559,11 @@ public MapState bind(String id, StateBinder visitor) { return visitor.bindMap(id, this, keyCoder, valueCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchMap(keyCoder, valueCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -594,6 +628,11 @@ public SetState bind(String id, StateBinder visitor) { return visitor.bindSet(id, this, elemCoder); } + @Override + public ResultT match(Cases cases) { + return cases.dispatchSet(elemCoder); + } + @SuppressWarnings("unchecked") @Override public void offerCoders(Coder[] coders) { @@ -657,6 +696,14 @@ public WatermarkHoldState bind(String id, StateBinder visitor) { return visitor.bindWatermark(id, this, timestampCombiner); } + @Override + public ResultT match(Cases cases) { + throw new UnsupportedOperationException( + String.format( + "%s is for internal use only and does not support case dispatch", + getClass().getSimpleName())); + } + @Override public void offerCoders(Coder[] coders) { }