From 1a05a6a66fabfaa2968f81a73df6b8a1a6fc1301 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 11 Apr 2017 16:50:47 -0700 Subject: [PATCH] Translate a Pipeline in SdkComponents --- .../core/construction/PTransforms.java | 17 +-- .../core/construction/SdkComponents.java | 50 +++++++++ .../core/construction/SdkComponentsTest.java | 100 ++++++++++++++++++ 3 files changed, 160 insertions(+), 7 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java index 6d2c6b6fe053..d25d342bfd65 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransforms.java @@ -66,13 +66,16 @@ static RunnerApi.PTransform toProto( components.registerPCollection((PCollection) taggedInput.getValue())); } for (Map.Entry, PValue> taggedOutput : appliedPTransform.getOutputs().entrySet()) { - checkArgument( - taggedOutput.getValue() instanceof PCollection, - "Unexpected output type %s", - taggedOutput.getValue().getClass()); - transformBuilder.putOutputs( - toProto(taggedOutput.getKey()), - components.registerPCollection((PCollection) taggedOutput.getValue())); + // TODO: Remove gating + if (taggedOutput.getValue() instanceof PCollection) { + checkArgument( + taggedOutput.getValue() instanceof PCollection, + "Unexpected output type %s", + taggedOutput.getValue().getClass()); + transformBuilder.putOutputs( + toProto(taggedOutput.getKey()), + components.registerPCollection((PCollection) taggedOutput.getValue())); + } } for (AppliedPTransform subtransform : subtransforms) { transformBuilder.addSubtransforms(components.getExistingPTransformId(subtransform)); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java index 2de8237c8f36..eb29b9a3ae7a 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java @@ -22,16 +22,24 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.Equivalence; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; +import com.google.common.collect.ListMultimap; import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Set; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.NameUtils; import org.apache.beam.sdk.values.PCollection; @@ -54,6 +62,48 @@ static SdkComponents create() { return new SdkComponents(); } + public static RunnerApi.Pipeline translatePipeline(Pipeline p) { + final SdkComponents components = create(); + final Collection rootIds = new HashSet<>(); + p.traverseTopologically( + new PipelineVisitor.Defaults() { + private final ListMultimap> children = + ArrayListMultimap.create(); + + @Override + public void leaveCompositeTransform(Node node) { + if (node.isRootNode()) { + for (AppliedPTransform pipelineRoot : children.get(node)) { + rootIds.add(components.getExistingPTransformId(pipelineRoot)); + } + } else { + children.put(node.getEnclosingNode(), node.toAppliedPTransform()); + try { + components.registerPTransform(node.toAppliedPTransform(), children.get(node)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public void visitPrimitiveTransform(Node node) { + children.put(node.getEnclosingNode(), node.toAppliedPTransform()); + try { + components.registerPTransform( + node.toAppliedPTransform(), Collections.>emptyList()); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + }); + // TODO: Display Data + return RunnerApi.Pipeline.newBuilder() + .setComponents(components.toComponents()) + .addAllRootTransformIds(rootIds) + .build(); + } + private SdkComponents() { this.componentsBuilder = RunnerApi.Components.newBuilder(); this.transformIds = HashBiMap.create(); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java index 82840d670973..7424886d1009 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java @@ -24,25 +24,43 @@ import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; +import com.google.common.base.Equivalence; import java.io.IOException; import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.coders.BigEndianLongCoder; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.AfterWatermark; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode; import org.hamcrest.Matchers; +import org.joda.time.Duration; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -59,6 +77,88 @@ public class SdkComponentsTest { private SdkComponents components = SdkComponents.create(); + @Test + public void translatePipeline() { + BigEndianLongCoder customCoder = BigEndianLongCoder.of(); + PCollection elems = pipeline.apply(GenerateSequence.from(0L).to(207L)); + PCollection counted = elems.apply(Count.globally()).setCoder(customCoder); + PCollection windowed = + counted.apply( + Window.into(FixedWindows.of(Duration.standardMinutes(7))) + .triggering( + AfterWatermark.pastEndOfWindow() + .withEarlyFirings(AfterPane.elementCountAtLeast(19))) + .accumulatingFiredPanes() + .withAllowedLateness(Duration.standardMinutes(3L))); + final WindowingStrategy windowedStrategy = windowed.getWindowingStrategy(); + PCollection> keyed = windowed.apply(WithKeys.of("foo")); + PCollection>> grouped = + keyed.apply(GroupByKey.create()); + + final RunnerApi.Pipeline pipelineProto = SdkComponents.translatePipeline(pipeline); + pipeline.traverseTopologically( + new PipelineVisitor() { + Set transforms = new HashSet<>(); + Set> pcollections = new HashSet<>(); + Set>> coders = new HashSet<>(); + Set> windowingStrategies = new HashSet<>(); + + @Override + public CompositeBehavior enterCompositeTransform(Node node) { + return CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void leaveCompositeTransform(Node node) { + if (node.isRootNode()) { + assertThat( + "Unexpected number of PTransforms", + pipelineProto.getComponents().getTransformsCount(), + equalTo(transforms.size())); + assertThat( + "Unexpected number of PCollections", + pipelineProto.getComponents().getPcollectionsCount(), + equalTo(pcollections.size())); + assertThat( + "Unexpected number of Coders", + pipelineProto.getComponents().getCodersCount(), + equalTo(coders.size())); + assertThat( + "Unexpected number of Windowing Strategies", + pipelineProto.getComponents().getWindowingStrategiesCount(), + equalTo(windowingStrategies.size())); + } else { + transforms.add(node); + } + } + + @Override + public void visitPrimitiveTransform(Node node) { + transforms.add(node); + } + + @Override + public void visitValue(PValue value, Node producer) { + if (value instanceof PCollection) { + PCollection pc = (PCollection) value; + pcollections.add(pc); + addCoders(pc.getCoder()); + windowingStrategies.add(pc.getWindowingStrategy()); + addCoders(pc.getWindowingStrategy().getWindowFn().windowCoder()); + } + } + + private void addCoders(Coder coder) { + coders.add(Equivalence.>identity().wrap(coder)); + if (coder instanceof StructuredCoder) { + for (Coder component : ((StructuredCoder ) coder).getComponents()) { + addCoders(component); + } + } + } + }); + } + @Test public void registerCoder() throws IOException { Coder coder =