From bcf92b8d1c2218dd17cfee687497d26cfdb77b2f Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 4 Apr 2017 17:43:48 -0700 Subject: [PATCH 1/5] Include Additional PTransform inputs in Transform Nodes Add the value of PTransform.getAdditionalInputs in the inputs of a TransformHierarchy node. Fork the Node constructor to reduce nullability This slightly simplifies the constructor implementation(s). Update the DirectRunner to track main inputs instead of all inputs. --- .../runners/direct/DirectGraphVisitor.java | 9 +++- .../runners/direct/ParDoEvaluatorFactory.java | 9 ++-- ...ttableProcessElementsEvaluatorFactory.java | 2 + .../direct/StatefulParDoEvaluatorFactory.java | 1 + .../beam/runners/direct/TransformInputs.java | 50 +++++++++++++++++++ .../beam/runners/direct/WatermarkManager.java | 17 ++++--- .../runners/direct/ParDoEvaluatorTest.java | 6 +-- .../beam/sdk/runners/TransformHierarchy.java | 28 +++++++---- 8 files changed, 98 insertions(+), 24 deletions(-) create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 01204e3049dd..decabefbeac1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -21,6 +21,7 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -83,7 +84,13 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { - for (PValue value : node.getInputs().values()) { + Collection mainInputs = + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform()); + if (!mainInputs.containsAll(node.getInputs().values())) { + System.out.printf( + "Main inputs reduced to %s from %s%n", mainInputs, node.getInputs().values()); + } + for (PValue value : mainInputs) { primitiveConsumers.put(value, appliedTransform); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 74470bfb8b8d..c52091e27cfe 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -20,7 +20,6 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +78,7 @@ public TransformEvaluator forApplication( (TransformEvaluator) createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, transform.getSideInputs(), @@ -102,6 +102,7 @@ public void cleanup() throws Exception { @SuppressWarnings({"unchecked", "rawtypes"}) DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( AppliedPTransform, PCollectionTuple, ?> application, + PCollection mainInput, StructuralKey inputBundleKey, DoFn doFn, List> sideInputs, @@ -120,6 +121,7 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( createParDoEvaluator( application, inputBundleKey, + mainInput, sideInputs, mainOutputTag, additionalOutputTags, @@ -132,6 +134,7 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( ParDoEvaluator createParDoEvaluator( AppliedPTransform, PCollectionTuple, ?> application, StructuralKey key, + PCollection mainInput, List> sideInputs, TupleTag mainOutputTag, List> additionalOutputTags, @@ -144,8 +147,7 @@ ParDoEvaluator createParDoEvaluator( evaluationContext, stepContext, application, - ((PCollection) Iterables.getOnlyElement(application.getInputs().values())) - .getWindowingStrategy(), + mainInput.getWindowingStrategy(), fn, key, sideInputs, @@ -173,5 +175,4 @@ static Map, PCollection> pcollections(Map, PValue> ou } return pcs; } - } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index dc85d87bc93f..4e7f4db65478 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -116,6 +116,8 @@ public void cleanup() throws Exception { delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), + (PCollection>>) + inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 985c3be4e9e9..e22edd187c53 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -117,6 +117,7 @@ private TransformEvaluator>> createEvaluator( DoFnLifecycleManagerRemovingTransformEvaluator> delegateEvaluator = delegateFactory.createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, application.getTransform().getUnderlyingParDo().getSideInputs(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java new file mode 100644 index 000000000000..4b06d313d2a4 --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.direct; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.Map; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */ +class TransformInputs { + /** + * Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link + * PTransform#getAdditionalInputs()}. + */ + public static Collection nonAdditionalInputs(AppliedPTransform application) { + ImmutableList.Builder mainInputs = ImmutableList.builder(); + PTransform transform = application.getTransform(); + for (Map.Entry, PValue> input : application.getInputs().entrySet()) { + if (!transform.getAdditionalInputs().containsKey(input.getKey())) { + mainInputs.add(input.getValue()); + } + } + checkArgument( + !mainInputs.build().isEmpty() || application.getInputs().isEmpty(), + "Expected at least one main input if any inputs exist"); + return mainInputs.build(); + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 4f1b8319dc2d..b15b52e314de 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -823,10 +823,11 @@ private Collection getInputProcessingWatermarks(AppliedPTransform getInputWatermarks(AppliedPTransform transform) inputWatermarksBuilder.add(THE_END_OF_TIME); } for (PValue pvalue : inputs.values()) { - Watermark producerOutputWatermark = - getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; - inputWatermarksBuilder.add(producerOutputWatermark); + if (graph.getPrimitiveConsumers(pvalue).contains(transform)) { + Watermark producerOutputWatermark = + getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; + inputWatermarksBuilder.add(producerOutputWatermark); + } } List inputCollectionWatermarks = inputWatermarksBuilder.build(); return inputCollectionWatermarks; diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index 286e44d1be04..3b2a22ee26f9 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -98,7 +98,7 @@ public void sideInputsNotReadyResultHasUnprocessedElements() { when(evaluationContext.createBundle(output)).thenReturn(outputBundle); ParDoEvaluator evaluator = - createEvaluator(singletonView, fn, output); + createEvaluator(singletonView, fn, inputPc, output); IntervalWindow nonGlobalWindow = new IntervalWindow(new Instant(0), new Instant(10_000L)); WindowedValue first = WindowedValue.valueInGlobalWindow(3); @@ -132,6 +132,7 @@ public void sideInputsNotReadyResultHasUnprocessedElements() { private ParDoEvaluator createEvaluator( PCollectionView singletonView, RecorderFn fn, + PCollection input, PCollection output) { when( evaluationContext.createSideInputReader( @@ -156,8 +157,7 @@ private ParDoEvaluator createEvaluator( evaluationContext, stepContext, transform, - ((PCollection) Iterables.getOnlyElement(transform.getInputs().values())) - .getWindowingStrategy(), + input.getWindowingStrategy(), fn, null /* key */, ImmutableList.>of(singletonView), diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 2f0e8efd7de8..9d73b4576788 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior; @@ -68,7 +67,7 @@ public TransformHierarchy() { producers = new HashMap<>(); producerInput = new HashMap<>(); unexpandedInputs = new HashMap<>(); - root = new Node(null, null, "", null); + root = new Node(); current = root; } @@ -252,26 +251,37 @@ public class Node { @VisibleForTesting boolean finishedSpecifying = false; + /** + * Creates the root-level node. The root level node has a null enclosing node, a null transform, + * an empty map of inputs, and a name equal to the empty string. + */ + private Node() { + this.enclosingNode = null; + this.transform = null; + this.fullName = ""; + this.inputs = Collections.emptyMap(); + } + /** * Creates a new Node with the given parent and transform. * - *

EnclosingNode and transform may both be null for a root-level node, which holds all other - * nodes. - * * @param enclosingNode the composite node containing this node * @param transform the PTransform tracked by this node * @param fullName the fully qualified name of the transform * @param input the unexpanded input to the transform */ private Node( - @Nullable Node enclosingNode, - @Nullable PTransform transform, + Node enclosingNode, + PTransform transform, String fullName, - @Nullable PInput input) { + PInput input) { this.enclosingNode = enclosingNode; this.transform = transform; this.fullName = fullName; - this.inputs = input == null ? Collections., PValue>emptyMap() : input.expand(); + ImmutableMap.Builder, PValue> inputs = ImmutableMap.builder(); + inputs.putAll(input.expand()); + inputs.putAll(transform.getAdditionalInputs()); + this.inputs = inputs.build(); } /** From 2ce3d6e783d94706be71a96e2ff3dc63084c3176 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 19 May 2017 15:02:29 -0700 Subject: [PATCH 2/5] fixup! Include Additional PTransform inputs in Transform Nodes --- .../apache/beam/runners/direct/DirectGraphVisitor.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index decabefbeac1..ce387fd5f855 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -35,6 +35,8 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -42,6 +44,7 @@ * input after the upstream transform has produced and committed output. */ class DirectGraphVisitor extends PipelineVisitor.Defaults { + private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class); private Map> producers = new HashMap<>(); @@ -87,8 +90,10 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { Collection mainInputs = TransformInputs.nonAdditionalInputs(node.toAppliedPTransform()); if (!mainInputs.containsAll(node.getInputs().values())) { - System.out.printf( - "Main inputs reduced to %s from %s%n", mainInputs, node.getInputs().values()); + LOG.debug( + "Inputs reduced to {} from {} by removing additional inputs", + mainInputs, + node.getInputs().values()); } for (PValue value : mainInputs) { primitiveConsumers.put(value, appliedTransform); From 0022c12cb60ce01d11fcbb999564aae5a7ffb7be Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 22 May 2017 10:18:43 -0700 Subject: [PATCH 3/5] fixup! Include Additional PTransform inputs in Transform Nodes --- .../apex/translation/TranslationContext.java | 4 +- .../core/construction}/TransformInputs.java | 4 +- .../construction/TransformInputsTest.java | 161 ++++++++++++++++++ .../runners/direct/DirectGraphVisitor.java | 1 + 4 files changed, 167 insertions(+), 3 deletions(-) rename runners/{direct-java/src/main/java/org/apache/beam/runners/direct => core-construction-java/src/main/java/org/apache/beam/runners/core/construction}/TransformInputs.java (96%) create mode 100644 runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java index aff3863624c4..94d13e177dec 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -93,7 +94,8 @@ public Map, PValue> getInputs() { } public InputT getInput() { - return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs().values()); + return (InputT) + Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); } public Map, PValue> getOutputs() { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java similarity index 96% rename from runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java rename to runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java index 4b06d313d2a4..2baf93a3c128 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformInputs.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.beam.runners.direct; +package org.apache.beam.runners.core.construction; import static com.google.common.base.Preconditions.checkArgument; @@ -29,7 +29,7 @@ import org.apache.beam.sdk.values.TupleTag; /** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */ -class TransformInputs { +public class TransformInputs { /** * Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link * PTransform#getAdditionalInputs()}. diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java new file mode 100644 index 000000000000..33a4b31fde1e --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static org.junit.Assert.assertThat; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TransformInputs}. */ +@RunWith(JUnit4.class) +public class TransformInputsTest { + @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void nonAdditionalInputsWithNoInputSucceeds() { + AppliedPTransform transform = + AppliedPTransform.of( + "input-free", + Collections., PValue>emptyMap(), + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat(TransformInputs.nonAdditionalInputs(transform), Matchers.empty()); + } + + @Test + public void nonAdditionalInputsWithOneMainInputSucceeds() { + PCollection input = pipeline.apply(GenerateSequence.from(1L)); + AppliedPTransform transform = + AppliedPTransform.of( + "input-single", + Collections., PValue>singletonMap(new TupleTag() {}, input), + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), Matchers.containsInAnyOrder(input)); + } + + @Test + public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { + PCollection input = pipeline.apply(GenerateSequence.from(1L)); + AppliedPTransform transform = + AppliedPTransform.of( + "additional-free", + Collections., PValue>singletonMap(new TupleTag() {}, input), + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), Matchers.containsInAnyOrder(input)); + } + + @Test + public void nonAdditionalInputsWithAdditionalInputsSucceeds() { + Map, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag() {}, pipeline.apply(GenerateSequence.from(3L))); + + Map, PValue> allInputs = new HashMap<>(); + PCollection mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag() {}, mainInts); + PCollection voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put( + new TupleTag() {}, voids); + allInputs.putAll(additionalInputs); + + AppliedPTransform transform = + AppliedPTransform.of( + "additional", + allInputs, + Collections., PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.containsInAnyOrder(mainInts, voids)); + } + + @Test + public void nonAdditionalInputsWithOnlyAdditionalInputsThrows() { + Map, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag() {}, pipeline.apply(GenerateSequence.from(3L))); + + AppliedPTransform transform = + AppliedPTransform.of( + "additional-only", + additionalInputs, + Collections., PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("at least one"); + TransformInputs.nonAdditionalInputs(transform); + } + + private static class TestTransform extends PTransform { + private final Map, PValue> additionalInputs; + + private TestTransform() { + this(Collections., PValue>emptyMap()); + } + + private TestTransform(Map, PValue> additionalInputs) { + this.additionalInputs = additionalInputs; + } + + @Override + public POutput expand(PInput input) { + return PDone.in(input.getPipeline()); + } + + @Override + public Map, PValue> getAdditionalInputs() { + return additionalInputs; + } + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index ce387fd5f855..438e47b559d5 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -26,6 +26,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.AppliedPTransform; From e5dc6c16426a8b784020ebdc1707bec8b62e863e Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 22 May 2017 12:40:27 -0700 Subject: [PATCH 4/5] fixup! Include Additional PTransform inputs in Transform Nodes --- .../runners/core/construction/TransformInputsTest.java | 10 +++++++--- .../runners/flink/FlinkBatchTranslationContext.java | 3 ++- .../flink/FlinkStreamingTranslationContext.java | 3 ++- .../runners/spark/translation/EvaluationContext.java | 4 +++- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java index 33a4b31fde1e..a03fde895af8 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -78,17 +78,21 @@ public void nonAdditionalInputsWithOneMainInputSucceeds() { @Test public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { - PCollection input = pipeline.apply(GenerateSequence.from(1L)); + Map, PValue> allInputs = new HashMap<>(); + PCollection mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag() {}, mainInts); + PCollection voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); AppliedPTransform transform = AppliedPTransform.of( "additional-free", - Collections., PValue>singletonMap(new TupleTag() {}, input), + allInputs, Collections., PValue>emptyMap(), new TestTransform(), pipeline); assertThat( - TransformInputs.nonAdditionalInputs(transform), Matchers.containsInAnyOrder(input)); + TransformInputs.nonAdditionalInputs(transform), + Matchers.containsInAnyOrder(voids, mainInts)); } @Test diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java index 0439119dfc40..6e7019848b19 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -20,6 +20,7 @@ import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -143,7 +144,7 @@ Map, PValue> getInputs(PTransform transform) { @SuppressWarnings("unchecked") T getInput(PTransform transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } Map, PValue> getOutputs(PTransform transform) { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index ea5f6b3162af..18525ce36b5d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -113,7 +114,7 @@ public TypeInformation> getTypeInfo(PCollection collecti @SuppressWarnings("unchecked") public T getInput(PTransform transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform); } public Map, PValue> getInputs(PTransform transform) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 8102926f6daa..0c6c4d1cb660 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.Pipeline; @@ -103,7 +104,8 @@ public void setCurrentTransform(AppliedPTransform transform) { public T getInput(PTransform transform) { @SuppressWarnings("unchecked") - T input = (T) Iterables.getOnlyElement(getInputs(transform).values()); + T input = + (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); return input; } From 94d837fe4edf0e5e4879afdc839c029f1d275323 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 22 May 2017 14:56:17 -0700 Subject: [PATCH 5/5] fixup! Include Additional PTransform inputs in Transform Nodes --- .../beam/runners/core/construction/TransformInputsTest.java | 1 + .../java/org/apache/beam/runners/direct/DirectGraphVisitor.java | 2 +- .../beam/runners/flink/FlinkStreamingTranslationContext.java | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java index a03fde895af8..f5b2c11e7923 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -82,6 +82,7 @@ public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { PCollection mainInts = pipeline.apply("MainInput", Create.of(12, 3)); allInputs.put(new TupleTag() {}, mainInts); PCollection voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put(new TupleTag() {}, voids); AppliedPTransform transform = AppliedPTransform.of( "additional-free", diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 438e47b559d5..ed4282bbdb8b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -89,7 +89,7 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { rootTransforms.add(appliedTransform); } else { Collection mainInputs = - TransformInputs.nonAdditionalInputs(node.toAppliedPTransform()); + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); if (!mainInputs.containsAll(node.getInputs().values())) { LOG.debug( "Inputs reduced to {} from {} by removing additional inputs", diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index 18525ce36b5d..74a5fb971144 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -114,7 +114,7 @@ public TypeInformation> getTypeInfo(PCollection collecti @SuppressWarnings("unchecked") public T getInput(PTransform transform) { - return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } public Map, PValue> getInputs(PTransform transform) {