Skip to content

Commit

Permalink
Merge 2656abe into 1c6861f
Browse files Browse the repository at this point in the history
  • Loading branch information
tgroh committed Jun 9, 2017
2 parents 1c6861f + 2656abe commit 26528b6
Show file tree
Hide file tree
Showing 20 changed files with 594 additions and 52 deletions.
Expand Up @@ -34,6 +34,7 @@
import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
Expand Down Expand Up @@ -93,7 +94,8 @@ public Map<TupleTag<?>, PValue> getInputs() {
}

public <InputT extends PValue> InputT getInput() {
return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs().values());
return (InputT)
Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform()));
}

public Map<TupleTag<?>, PValue> getOutputs() {
Expand Down
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.core.construction;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Map;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;

/** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */
public class TransformInputs {
/**
* Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link
* PTransform#getAdditionalInputs()}.
*/
public static Collection<PValue> nonAdditionalInputs(AppliedPTransform<?, ?, ?> application) {
ImmutableList.Builder<PValue> mainInputs = ImmutableList.builder();
PTransform<?, ?> transform = application.getTransform();
for (Map.Entry<TupleTag<?>, PValue> input : application.getInputs().entrySet()) {
if (!transform.getAdditionalInputs().containsKey(input.getKey())) {
mainInputs.add(input.getValue());
}
}
checkArgument(
!mainInputs.build().isEmpty() || application.getInputs().isEmpty(),
"Expected at least one main input if any inputs exist");
return mainInputs.build();
}
}
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.core.construction;

import static org.junit.Assert.assertThat;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link TransformInputs}. */
@RunWith(JUnit4.class)
public class TransformInputsTest {
@Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
@Rule public ExpectedException thrown = ExpectedException.none();

@Test
public void nonAdditionalInputsWithNoInputSucceeds() {
AppliedPTransform<PInput, POutput, TestTransform> transform =
AppliedPTransform.of(
"input-free",
Collections.<TupleTag<?>, PValue>emptyMap(),
Collections.<TupleTag<?>, PValue>emptyMap(),
new TestTransform(),
pipeline);

assertThat(TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>empty());
}

@Test
public void nonAdditionalInputsWithOneMainInputSucceeds() {
PCollection<Long> input = pipeline.apply(GenerateSequence.from(1L));
AppliedPTransform<PInput, POutput, TestTransform> transform =
AppliedPTransform.of(
"input-single",
Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>() {}, input),
Collections.<TupleTag<?>, PValue>emptyMap(),
new TestTransform(),
pipeline);

assertThat(
TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>containsInAnyOrder(input));
}

@Test
public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() {
Map<TupleTag<?>, PValue> allInputs = new HashMap<>();
PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3));
allInputs.put(new TupleTag<Integer>() {}, mainInts);
PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of()));
allInputs.put(new TupleTag<Void>() {}, voids);
AppliedPTransform<PInput, POutput, TestTransform> transform =
AppliedPTransform.of(
"additional-free",
allInputs,
Collections.<TupleTag<?>, PValue>emptyMap(),
new TestTransform(),
pipeline);

assertThat(
TransformInputs.nonAdditionalInputs(transform),
Matchers.<PValue>containsInAnyOrder(voids, mainInts));
}

@Test
public void nonAdditionalInputsWithAdditionalInputsSucceeds() {
Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>();
additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3")));
additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L)));

Map<TupleTag<?>, PValue> allInputs = new HashMap<>();
PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3));
allInputs.put(new TupleTag<Integer>() {}, mainInts);
PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of()));
allInputs.put(
new TupleTag<Void>() {}, voids);
allInputs.putAll(additionalInputs);

AppliedPTransform<PInput, POutput, TestTransform> transform =
AppliedPTransform.of(
"additional",
allInputs,
Collections.<TupleTag<?>, PValue>emptyMap(),
new TestTransform(additionalInputs),
pipeline);

assertThat(
TransformInputs.nonAdditionalInputs(transform),
Matchers.<PValue>containsInAnyOrder(mainInts, voids));
}

@Test
public void nonAdditionalInputsWithOnlyAdditionalInputsThrows() {
Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>();
additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3")));
additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L)));

AppliedPTransform<PInput, POutput, TestTransform> transform =
AppliedPTransform.of(
"additional-only",
additionalInputs,
Collections.<TupleTag<?>, PValue>emptyMap(),
new TestTransform(additionalInputs),
pipeline);

thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("at least one");
TransformInputs.nonAdditionalInputs(transform);
}

private static class TestTransform extends PTransform<PInput, POutput> {
private final Map<TupleTag<?>, PValue> additionalInputs;

private TestTransform() {
this(Collections.<TupleTag<?>, PValue>emptyMap());
}

private TestTransform(Map<TupleTag<?>, PValue> additionalInputs) {
this.additionalInputs = additionalInputs;
}

@Override
public POutput expand(PInput input) {
return PDone.in(input.getPipeline());
}

@Override
public Map<TupleTag<?>, PValue> getAdditionalInputs() {
return additionalInputs;
}
}
}
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.runners.direct;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.collect.ListMultimap;
import java.util.Collection;
import java.util.List;
Expand All @@ -36,31 +38,45 @@
class DirectGraph {
private final Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers;
private final Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters;
private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers;
private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers;
private final ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers;

private final Set<AppliedPTransform<?, ?, ?>> rootTransforms;
private final Map<AppliedPTransform<?, ?, ?>, String> stepNames;

public static DirectGraph create(
Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers,
Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters,
ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers,
ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers,
ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers,
Set<AppliedPTransform<?, ?, ?>> rootTransforms,
Map<AppliedPTransform<?, ?, ?>, String> stepNames) {
return new DirectGraph(producers, viewWriters, primitiveConsumers, rootTransforms, stepNames);
return new DirectGraph(
producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames);
}

private DirectGraph(
Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers,
Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters,
ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers,
ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers,
ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers,
Set<AppliedPTransform<?, ?, ?>> rootTransforms,
Map<AppliedPTransform<?, ?, ?>, String> stepNames) {
this.producers = producers;
this.viewWriters = viewWriters;
this.primitiveConsumers = primitiveConsumers;
this.perElementConsumers = perElementConsumers;
this.allConsumers = allConsumers;
this.rootTransforms = rootTransforms;
this.stepNames = stepNames;
for (AppliedPTransform<?, ?, ?> step : stepNames.keySet()) {
for (PValue input : step.getInputs().values()) {
checkArgument(
allConsumers.get(input).contains(step),
"Step %s lists value %s as input, but it is not in the graph of consumers",
step.getFullName(),
input);
}
}
}

AppliedPTransform<?, ?, ?> getProducer(PCollection<?> produced) {
Expand All @@ -71,8 +87,12 @@ private DirectGraph(
return viewWriters.get(view);
}

List<AppliedPTransform<?, ?, ?>> getPrimitiveConsumers(PValue consumed) {
return primitiveConsumers.get(consumed);
List<AppliedPTransform<?, ?, ?>> getPerElementConsumers(PValue consumed) {
return perElementConsumers.get(consumed);
}

List<AppliedPTransform<?, ?, ?>> getAllConsumers(PValue consumed) {
return allConsumers.get(consumed);
}

Set<AppliedPTransform<?, ?, ?>> getRootTransforms() {
Expand Down
Expand Up @@ -22,10 +22,12 @@
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
Expand All @@ -37,19 +39,24 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.PValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the
* {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume
* input after the upstream transform has produced and committed output.
*/
class DirectGraphVisitor extends PipelineVisitor.Defaults {
private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class);

private Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers = new HashMap<>();
private Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters = new HashMap<>();
private Set<PCollectionView<?>> consumedViews = new HashSet<>();

private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers =
private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers =
ArrayListMultimap.create();
private ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers =
ArrayListMultimap.create();

private Set<AppliedPTransform<?, ?, ?>> rootTransforms = new HashSet<>();
Expand Down Expand Up @@ -94,8 +101,19 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) {
if (node.getInputs().isEmpty()) {
rootTransforms.add(appliedTransform);
} else {
Collection<PValue> mainInputs =
TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline()));
if (!mainInputs.containsAll(node.getInputs().values())) {
LOG.debug(
"Inputs reduced to {} from {} by removing additional inputs",
mainInputs,
node.getInputs().values());
}
for (PValue value : mainInputs) {
perElementConsumers.put(value, appliedTransform);
}
for (PValue value : node.getInputs().values()) {
primitiveConsumers.put(value, appliedTransform);
allConsumers.put(value, appliedTransform);
}
}
if (node.getTransform() instanceof ParDo.MultiOutput) {
Expand All @@ -106,7 +124,7 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) {
}
}

@Override
@Override
public void visitValue(PValue value, TransformHierarchy.Node producer) {
AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer);
if (value instanceof PCollection && !producers.containsKey(value)) {
Expand All @@ -131,6 +149,6 @@ private String genStepName() {
public DirectGraph getGraph() {
checkState(finalized, "Can't get a graph before the Pipeline has been completely traversed");
return DirectGraph.create(
producers, viewWriters, primitiveConsumers, rootTransforms, stepNames);
producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames);
}
}
Expand Up @@ -355,7 +355,7 @@ public final CommittedResult handleResult(
for (CommittedBundle<?> outputBundle : committedResult.getOutputs()) {
allUpdates.offer(
ExecutorUpdate.fromBundle(
outputBundle, graph.getPrimitiveConsumers(outputBundle.getPCollection())));
outputBundle, graph.getPerElementConsumers(outputBundle.getPCollection())));
}
CommittedBundle<?> unprocessedInputs = committedResult.getUnprocessedInputs();
if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) {
Expand Down

0 comments on commit 26528b6

Please sign in to comment.