diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java new file mode 100644 index 000000000000..7ec0863860b6 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** + * Utilities for converting {@link PTransform PTransforms} to and from {@link RunnerApi Runner API + * protocol buffers}. + */ +public class PTransforms { + private static final Map, TransformPayloadTranslator> + KNOWN_PAYLOAD_TRANSLATORS = + ImmutableMap., TransformPayloadTranslator>builder().build(); + // TODO: ParDoPayload, WindowIntoPayload, ReadPayload, CombinePayload + // TODO: "Flatten Payload", etc? + // TODO: Load via service loader. + private PTransforms() {} + + /** + * Translates an {@link AppliedPTransform} into a runner API proto. + * + *

Does not register the {@code appliedPTransform} within the provided {@link SdkComponents}. + */ + static RunnerApi.PTransform toProto( + AppliedPTransform appliedPTransform, + List> subtransforms, + SdkComponents components) + throws IOException { + RunnerApi.PTransform.Builder transformBuilder = RunnerApi.PTransform.newBuilder(); + for (Map.Entry, PValue> taggedInput : appliedPTransform.getInputs().entrySet()) { + checkArgument( + taggedInput.getValue() instanceof PCollection, + "Unexpected input type %s", + taggedInput.getValue().getClass()); + transformBuilder.putInputs( + toProto(taggedInput.getKey()), + components.registerPCollection((PCollection) taggedInput.getValue())); + } + for (Map.Entry, PValue> taggedOutput : appliedPTransform.getOutputs().entrySet()) { + checkArgument( + taggedOutput.getValue() instanceof PCollection, + "Unexpected output type %s", + taggedOutput.getValue().getClass()); + transformBuilder.putOutputs( + toProto(taggedOutput.getKey()), + components.registerPCollection((PCollection) taggedOutput.getValue())); + } + for (AppliedPTransform subtransform : subtransforms) { + transformBuilder.addSubtransforms(components.getExistingPTransformId(subtransform)); + } + + transformBuilder.setUniqueName(appliedPTransform.getFullName()); + // TODO: Display Data + + PTransform transform = appliedPTransform.getTransform(); + if (KNOWN_PAYLOAD_TRANSLATORS.containsKey(transform.getClass())) { + FunctionSpec payload = + KNOWN_PAYLOAD_TRANSLATORS + .get(transform.getClass()) + .translate(appliedPTransform, components); + transformBuilder.setSpec(payload); + } + + return transformBuilder.build(); + } + + private static String toProto(TupleTag tag) { + return tag.getId(); + } + + /** + * A translator consumes a {@link PTransform} application and produces the appropriate + * FunctionSpec for a distinguished or primitive transform within the Beam runner API. + */ + public interface TransformPayloadTranslator> { + FunctionSpec translate(AppliedPTransform transform, SdkComponents components); + } +} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java index 3f1748514845..35af3006d2fe 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java @@ -18,10 +18,14 @@ package org.apache.beam.runners.core.construction; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.common.base.Equivalence; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import java.io.IOException; +import java.util.List; import java.util.Set; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; @@ -62,20 +66,52 @@ private SdkComponents() { * Registers the provided {@link AppliedPTransform} into this {@link SdkComponents}, returning a * unique ID for the {@link AppliedPTransform}. Multiple registrations of the same * {@link AppliedPTransform} will return the same unique ID. + * + *

All of the children must already be registered within this {@link SdkComponents}. */ - String registerPTransform(AppliedPTransform pTransform) { - String existing = transformIds.get(pTransform); + String registerPTransform( + AppliedPTransform appliedPTransform, List> children) + throws IOException { + String name = getApplicationName(appliedPTransform); + // If this transform is present in the components, nothing to do. return the existing name. + // Otherwise the transform must be translated and added to the components. + if (componentsBuilder.getTransformsOrDefault(name, null) != null) { + return name; + } + checkNotNull(children, "child nodes may not be null"); + componentsBuilder.putTransforms(name, PTransforms.toProto(appliedPTransform, children, this)); + return name; + } + + /** + * Gets the ID for the provided {@link AppliedPTransform}. The provided {@link AppliedPTransform} + * will not be added to the components produced by this {@link SdkComponents} until it is + * translated via {@link #registerPTransform(AppliedPTransform, List)}. + */ + private String getApplicationName(AppliedPTransform appliedPTransform) { + String existing = transformIds.get(appliedPTransform); if (existing != null) { return existing; } - String name = pTransform.getFullName(); + + String name = appliedPTransform.getFullName(); if (name.isEmpty()) { - name = uniqify("unnamed_ptransform", transformIds.values()); + name = "unnamed-ptransform"; } - transformIds.put(pTransform, name); + name = uniqify(name, transformIds.values()); + transformIds.put(appliedPTransform, name); return name; } + String getExistingPTransformId(AppliedPTransform appliedPTransform) { + checkArgument( + transformIds.containsKey(appliedPTransform), + "%s %s has not been previously registered", + AppliedPTransform.class.getSimpleName(), + appliedPTransform); + return transformIds.get(appliedPTransform); + } + /** * Registers the provided {@link PCollection} into this {@link SdkComponents}, returning a unique * ID for the {@link PCollection}. Multiple registrations of the same {@link PCollection} will diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformsTest.java new file mode 100644 index 000000000000..4e3cdb63bcb5 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformsTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.PTransform; +import org.apache.beam.sdk.io.CountingInput; +import org.apache.beam.sdk.io.CountingInput.UnboundedCountingInput; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +/** + * Tests for {@link PTransforms}. + */ +@RunWith(Parameterized.class) +public class PTransformsTest { + + @Parameters(name = "{index}: {0}") + public static Iterable data() { + // This pipeline exists for construction, not to run any test. + // TODO: Leaf node with understood payload - i.e. validate payloads + ToAndFromProtoSpec readLeaf = ToAndFromProtoSpec.leaf(read(TestPipeline.create())); + ToAndFromProtoSpec readMultipleInAndOut = + ToAndFromProtoSpec.leaf(multiMultiParDo(TestPipeline.create())); + TestPipeline compositeReadPipeline = TestPipeline.create(); + ToAndFromProtoSpec compositeRead = + ToAndFromProtoSpec.composite( + countingInput(compositeReadPipeline), + ToAndFromProtoSpec.leaf(read(compositeReadPipeline))); + return ImmutableList.builder() + .add(readLeaf) + .add(readMultipleInAndOut) + .add(compositeRead) + // TODO: Composite with multiple children + // TODO: Composite with a composite child + .build(); + } + + @AutoValue + abstract static class ToAndFromProtoSpec { + public static ToAndFromProtoSpec leaf(AppliedPTransform transform) { + return new AutoValue_PTransformsTest_ToAndFromProtoSpec( + transform, Collections.emptyList()); + } + + public static ToAndFromProtoSpec composite( + AppliedPTransform topLevel, ToAndFromProtoSpec spec, ToAndFromProtoSpec... specs) { + List childSpecs = new ArrayList<>(); + childSpecs.add(spec); + childSpecs.addAll(Arrays.asList(specs)); + return new AutoValue_PTransformsTest_ToAndFromProtoSpec(topLevel, childSpecs); + } + + abstract AppliedPTransform getTransform(); + abstract Collection getChildren(); + } + + @Parameter(0) + public ToAndFromProtoSpec spec; + + @Test + public void toAndFromProto() throws IOException { + SdkComponents components = SdkComponents.create(); + RunnerApi.PTransform converted = convert(spec, components); + Components protoComponents = components.toComponents(); + + // Sanity checks + assertThat(converted.getInputsCount(), equalTo(spec.getTransform().getInputs().size())); + assertThat(converted.getOutputsCount(), equalTo(spec.getTransform().getOutputs().size())); + assertThat(converted.getSubtransformsCount(), equalTo(spec.getChildren().size())); + + assertThat(converted.getUniqueName(), equalTo(spec.getTransform().getFullName())); + for (PValue inputValue : spec.getTransform().getInputs().values()) { + PCollection inputPc = (PCollection) inputValue; + protoComponents.getPcollectionsOrThrow(components.registerPCollection(inputPc)); + } + for (PValue outputValue : spec.getTransform().getOutputs().values()) { + PCollection outputPc = (PCollection) outputValue; + protoComponents.getPcollectionsOrThrow(components.registerPCollection(outputPc)); + } + } + + private RunnerApi.PTransform convert(ToAndFromProtoSpec spec, SdkComponents components) + throws IOException { + List> childTransforms = new ArrayList<>(); + for (ToAndFromProtoSpec child : spec.getChildren()) { + childTransforms.add(child.getTransform()); + System.out.println("Converting child " + child); + convert(child, components); + // Sanity call + components.getExistingPTransformId(child.getTransform()); + } + PTransform convert = PTransforms.toProto(spec.getTransform(), childTransforms, components); + // Make sure the converted transform is registered. Convert it independently, but if this is a + // child spec, the child must be in the components. + components.registerPTransform(spec.getTransform(), childTransforms); + return convert; + } + + private static class TestDoFn extends DoFn> { + // Exists to stop the ParDo application from throwing + @ProcessElement public void process(ProcessContext context) {} + } + + private static AppliedPTransform countingInput(Pipeline pipeline) { + UnboundedCountingInput input = CountingInput.unbounded(); + PCollection pcollection = pipeline.apply(input); + return AppliedPTransform., UnboundedCountingInput>of( + "Count", pipeline.begin().expand(), pcollection.expand(), input, pipeline); + } + + private static AppliedPTransform read(Pipeline pipeline) { + Read.Unbounded transform = Read.from(CountingSource.unbounded()); + PCollection pcollection = pipeline.apply(transform); + return AppliedPTransform., Read.Unbounded>of( + "ReadTheCount", pipeline.begin().expand(), pcollection.expand(), transform, pipeline); + } + + private static AppliedPTransform multiMultiParDo(Pipeline pipeline) { + PCollectionView view = + pipeline.apply(Create.of("foo")).apply(View.asSingleton()); + PCollection input = pipeline.apply(CountingInput.unbounded()); + ParDo.MultiOutput> parDo = + ParDo.of(new TestDoFn()) + .withSideInputs(view) + .withOutputTags( + new TupleTag>() {}, + TupleTagList.of(new TupleTag>() {})); + PCollectionTuple output = input.apply(parDo); + + Map, PValue> inputs = new HashMap<>(); + inputs.putAll(parDo.getAdditionalInputs()); + inputs.putAll(input.expand()); + + return AppliedPTransform + ., PCollectionTuple, ParDo.MultiOutput>>of( + "MultiParDoInAndOut", inputs, output.expand(), parDo, pipeline); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java index 1854e5a449e0..895aec48572d 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertThat; import java.io.IOException; +import java.util.Collections; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; @@ -32,6 +33,7 @@ import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; import org.apache.beam.sdk.io.CountingInput; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -87,28 +89,92 @@ public void registerCoderEqualsNotSame() throws IOException { } @Test - public void registerTransform() { + public void registerTransformNoChildren() throws IOException { Create.Values create = Create.of(1, 2, 3); PCollection pt = pipeline.apply(create); String userName = "my_transform/my_nesting"; AppliedPTransform transform = AppliedPTransform., Create.Values>of( userName, pipeline.begin().expand(), pt.expand(), create, pipeline); - String componentName = components.registerPTransform(transform); + String componentName = + components.registerPTransform( + transform, Collections.>emptyList()); assertThat(componentName, equalTo(userName)); - assertThat(components.registerPTransform(transform), equalTo(componentName)); + assertThat(components.getExistingPTransformId(transform), equalTo(componentName)); } @Test - public void registerTransformIdEmptyFullName() { + public void registerTransformAfterChildren() throws IOException { + Create.Values create = Create.of(1L, 2L, 3L); + CountingInput.UnboundedCountingInput createChild = CountingInput.unbounded(); + + PCollection pt = pipeline.apply(create); + String userName = "my_transform"; + String childUserName = "my_transform/my_nesting"; + AppliedPTransform transform = + AppliedPTransform., Create.Values>of( + userName, pipeline.begin().expand(), pt.expand(), create, pipeline); + AppliedPTransform childTransform = + AppliedPTransform., CountingInput.UnboundedCountingInput>of( + childUserName, pipeline.begin().expand(), pt.expand(), createChild, pipeline); + + String childId = components.registerPTransform(childTransform, + Collections.>emptyList()); + String parentId = components.registerPTransform(transform, + Collections.>singletonList(childTransform)); + Components components = this.components.toComponents(); + assertThat(components.getTransformsOrThrow(parentId).getSubtransforms(0), equalTo(childId)); + assertThat(components.getTransformsOrThrow(childId).getSubtransformsCount(), equalTo(0)); + } + + @Test + public void registerTransformEmptyFullName() throws IOException { Create.Values create = Create.of(1, 2, 3); PCollection pt = pipeline.apply(create); AppliedPTransform transform = AppliedPTransform., Create.Values>of( "", pipeline.begin().expand(), pt.expand(), create, pipeline); - String assignedName = components.registerPTransform(transform); - assertThat(assignedName, not(isEmptyOrNullString())); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(transform.toString()); + components.getExistingPTransformId(transform); + } + + @Test + public void registerTransformNullComponents() throws IOException { + Create.Values create = Create.of(1, 2, 3); + PCollection pt = pipeline.apply(create); + String userName = "my_transform/my_nesting"; + AppliedPTransform transform = + AppliedPTransform., Create.Values>of( + userName, pipeline.begin().expand(), pt.expand(), create, pipeline); + thrown.expect(NullPointerException.class); + thrown.expectMessage("child nodes may not be null"); + components.registerPTransform(transform, null); + } + + /** + * Tests that trying to register a transform which has unregistered children throws. + */ + @Test + public void registerTransformWithUnregisteredChildren() throws IOException { + Create.Values create = Create.of(1L, 2L, 3L); + CountingInput.UnboundedCountingInput createChild = CountingInput.unbounded(); + + PCollection pt = pipeline.apply(create); + String userName = "my_transform"; + String childUserName = "my_transform/my_nesting"; + AppliedPTransform transform = + AppliedPTransform., Create.Values>of( + userName, pipeline.begin().expand(), pt.expand(), create, pipeline); + AppliedPTransform childTransform = + AppliedPTransform., CountingInput.UnboundedCountingInput>of( + childUserName, pipeline.begin().expand(), pt.expand(), createChild, pipeline); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(childTransform.toString()); + components.registerPTransform( + transform, Collections.>singletonList(childTransform)); } @Test diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index 8fe483109127..e69d6d87b726 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -1444,7 +1444,7 @@ public List> getSideInputs() { public Map, PValue> getAdditionalInputs() { ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput); + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); } return additionalInputs.build(); } @@ -1900,7 +1900,7 @@ public List> getSideInputs() { public Map, PValue> getAdditionalInputs() { ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput); + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); } return additionalInputs.build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 3de845b6b46c..d83270a7f48d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -664,7 +664,7 @@ public List> getSideInputs() { public Map, PValue> getAdditionalInputs() { ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput); + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); } return additionalInputs.build(); } @@ -811,7 +811,7 @@ public List> getSideInputs() { public Map, PValue> getAdditionalInputs() { ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput); + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); } return additionalInputs.build(); }