Skip to content

Commit

Permalink
Add Parameters to finishSpecifying
Browse files Browse the repository at this point in the history
Remove the need to use getProducingTransformInternal in TypedPValue.

Ensure that all nodes are finished specifying before a call to
TransformHierarchy#visit. This ensures that all nodes are fully
specified without requiring the Pipeline or Runner to do so explicitly.
  • Loading branch information
tgroh committed Dec 12, 2016
1 parent 52d29c5 commit d6d7c1d
Show file tree
Hide file tree
Showing 18 changed files with 169 additions and 169 deletions.
Expand Up @@ -27,7 +27,6 @@
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
Expand All @@ -50,7 +49,6 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults {
private Set<PCollectionView<?>> views = new HashSet<>();
private Set<AppliedPTransform<?, ?, ?>> rootTransforms = new HashSet<>();
private Map<AppliedPTransform<?, ?, ?>, String> stepNames = new HashMap<>();
private Set<PValue> toFinalize = new HashSet<>();
private int numTransforms = 0;
private boolean finalized = false;

Expand Down Expand Up @@ -79,7 +77,6 @@ public void leaveCompositeTransform(TransformHierarchy.Node node) {

@Override
public void visitPrimitiveTransform(TransformHierarchy.Node node) {
toFinalize.removeAll(node.getInputs());
AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node);
stepNames.put(appliedTransform, genStepName());
if (node.getInputs().isEmpty()) {
Expand All @@ -93,8 +90,6 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) {

@Override
public void visitValue(PValue value, TransformHierarchy.Node producer) {
toFinalize.add(value);

AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer);
if (!producers.containsKey(value)) {
producers.put(value, appliedTransform);
Expand All @@ -119,20 +114,6 @@ private String genStepName() {
return String.format("s%s", numTransforms++);
}

/**
* Returns all of the {@link PValue PValues} that have been produced but not consumed. These
* {@link PValue PValues} should be finalized by the {@link PipelineRunner} before the
* {@link Pipeline} is executed.
*/
public void finishSpecifyingRemainder() {
checkState(
finalized,
"Can't call finishSpecifyingRemainder before the Pipeline has been completely traversed");
for (PValue unfinalized : toFinalize) {
unfinalized.finishSpecifying();
}
}

/**
* Get the graph constructed by this {@link DirectGraphVisitor}, which provides
* lookups for producers and consumers of {@link PValue PValues}.
Expand Down
Expand Up @@ -300,7 +300,6 @@ public DirectPipelineResult run(Pipeline pipeline) {
MetricsEnvironment.setMetricsSupported(true);
DirectGraphVisitor graphVisitor = new DirectGraphVisitor();
pipeline.traverseTopologically(graphVisitor);
graphVisitor.finishSpecifyingRemainder();

@SuppressWarnings("rawtypes")
KeyedPValueTrackingVisitor keyedPValueVisitor =
Expand Down
Expand Up @@ -20,14 +20,14 @@
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import com.google.common.collect.Iterables;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.CountingInput;
import org.apache.beam.sdk.io.CountingSource;
import org.apache.beam.sdk.io.Read;
Expand Down Expand Up @@ -110,6 +110,7 @@ public void getRootTransformsContainsEmptyFlatten() {
FlattenPCollectionList<String> flatten = Flatten.pCollections();
PCollectionList<String> emptyList = PCollectionList.empty(p);
PCollection<String> empty = emptyList.apply(flatten);
empty.setCoder(StringUtf8Coder.of());
p.traverseTopologically(visitor);
DirectGraph graph = visitor.getGraph();
assertThat(
Expand Down Expand Up @@ -175,27 +176,6 @@ public void getValueToConsumersWithDuplicateInputSucceeds() {
assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable());
}

@Test
public void getUnfinalizedPValuesContainsDanglingOutputs() {
PCollection<String> created = p.apply(Create.of("1", "2", "3"));
PCollection<String> transformed =
created.apply(
ParDo.of(
new DoFn<String, String>() {
@ProcessElement
public void processElement(DoFn<String, String>.ProcessContext c)
throws Exception {
c.output(Integer.toString(c.element().length()));
}
}));

assertThat(transformed.isFinishedSpecifyingInternal(), is(false));

p.traverseTopologically(visitor);
visitor.finishSpecifyingRemainder();
assertThat(transformed.isFinishedSpecifyingInternal(), is(true));
}

@Test
public void getStepNamesContainsAllTransforms() {
PCollection<String> created = p.apply(Create.of("1", "2", "3"));
Expand Down Expand Up @@ -253,12 +233,4 @@ public void getGraphWithoutVisitingThrows() {
thrown.expectMessage("get a graph");
visitor.getGraph();
}

@Test
public void finishSpecifyingRemainderWithoutVisitingThrows() {
thrown.expect(IllegalStateException.class);
thrown.expectMessage("completely traversed");
thrown.expectMessage("finishSpecifyingRemainder");
visitor.finishSpecifyingRemainder();
}
}
Expand Up @@ -26,6 +26,7 @@
import com.google.common.collect.Iterables;
import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.Create;
Expand Down Expand Up @@ -122,6 +123,7 @@ public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Except
PCollectionList<Integer> list = PCollectionList.empty(p);

PCollection<Integer> flattened = list.apply(Flatten.<Integer>pCollections());
flattened.setCoder(VarIntCoder.of());

EvaluationContext evaluationContext = mock(EvaluationContext.class);
when(evaluationContext.createBundle(flattened))
Expand Down
Expand Up @@ -171,6 +171,8 @@ public <OutputT extends POutput> OutputT apply(
* Runs the {@link Pipeline} using its {@link PipelineRunner}.
*/
public PipelineResult run() {
// Ensure all of the nodes are fully specified before a PipelineRunner gets access to the
// pipeline.
LOG.debug("Running {} via {}", this, runner);
try {
return runner.run(this);
Expand Down Expand Up @@ -281,6 +283,7 @@ public void visitValue(PValue value, TransformHierarchy.Node producer) { }
* <p>Typically invoked by {@link PipelineRunner} subclasses.
*/
public void traverseTopologically(PipelineVisitor visitor) {
// Ensure all nodes are fully specified before visiting the pipeline
Set<PValue> visitedValues =
// Visit all the transforms, which should implicitly visit all the values.
transforms.visit(visitor);
Expand Down
Expand Up @@ -26,8 +26,8 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
Expand All @@ -45,13 +45,15 @@
public class TransformHierarchy {
private final Node root;
private final Map<POutput, Node> producers;
private final Map<PValue, PInput> producerInput;
// Maintain a stack based on the enclosing nodes
private Node current;

public TransformHierarchy() {
root = new Node(null, null, "", null);
current = root;
producers = new HashMap<>();
producerInput = new HashMap<>();
}

/**
Expand Down Expand Up @@ -83,12 +85,15 @@ public Node pushNode(String name, PInput input, PTransform<?, ?> transform) {
* specified, and have been produced by a node in this graph.
*/
public void finishSpecifyingInput() {
// Inputs must be completely specified before they are consumed by a transform.
// Inputs must be completely specified before they are consumed by a transform. All component
// inputs must be finished specifying before the overall input
for (PValue inputValue : current.getInputs()) {
inputValue.finishSpecifying();
Node producerNode = getProducer(inputValue);
PInput input = producerInput.remove(inputValue);
inputValue.finishSpecifying(input, producerNode.getTransform());
checkState(producers.get(inputValue) != null, "Producer unknown for input %s", inputValue);
inputValue.finishSpecifying();
}
current.input.finishSpecifying();
}

/**
Expand All @@ -102,12 +107,14 @@ public void finishSpecifyingInput() {
* nodes.
*/
public void setOutput(POutput output) {
output.finishSpecifyingOutput();
for (PValue value : output.expand()) {
if (!producers.containsKey(value)) {
producers.put(value, current);
}
value.finishSpecifyingOutput(current.input, current.transform);
producerInput.put(value, current.input);
}
output.finishSpecifyingOutput(current.input, current.transform);
current.setOutput(output);
// TODO: Replace with a "generateDefaultNames" method.
output.recordAsOutput(current.toAppliedPTransform());
Expand All @@ -127,27 +134,26 @@ Node getProducer(PValue produced) {
return producers.get(produced);
}

/**
* Returns all producing transforms for the {@link PValue PValues} contained
* in {@code output}.
*/
List<Node> getProducingTransforms(POutput output) {
List<Node> producingTransforms = new ArrayList<>();
for (PValue value : output.expand()) {
Node producer = getProducer(value);
if (producer != null) {
producingTransforms.add(producer);
}
}
return producingTransforms;
}

public Set<PValue> visit(PipelineVisitor visitor) {
finishSpecifying();
Set<PValue> visitedValues = new HashSet<>();
root.visit(visitor, visitedValues);
return visitedValues;
}

/**
* Finish specifying any remaining nodes within the {@link TransformHierarchy}. These are {@link
* PValue PValues} that are produced as output of some {@link PTransform} but are never consumed
* as input. These values must still be finished specifying.
*/
private void finishSpecifying() {
for (Entry<PValue, PInput> producerInputEntry : producerInput.entrySet()) {
PValue value = producerInputEntry.getKey();
value.finishSpecifying(producerInputEntry.getValue(), getProducer(value).getTransform());
}
producerInput.clear();
}

public Node getCurrent() {
return current;
}
Expand Down
Expand Up @@ -155,11 +155,25 @@ public Pipeline getPipeline() {

@Override
public void finishSpecifying() {
for (TaggedKeyedPCollection<K, ?> taggedPCollection : keyedCollections) {
taggedPCollection.pCollection.finishSpecifying();
// TODO: Make sure key coder is consistent between PCollections. All component PCollections will
// have already been finished.
}

private static <K, V> Coder<K> getKeyCoder(PCollection<KV<K, V>> pc) {
// TODO: This should already have run coder inference for output, but may not have been consumed
// as input yet (and won't be fully specified); This is fine

// Assumes that the PCollection uses a KvCoder.
Coder<?> entryCoder = pc.getCoder();
if (!(entryCoder instanceof KvCoder<?, ?>)) {
throw new IllegalArgumentException("PCollection does not use a KvCoder");
}
@SuppressWarnings("unchecked")
KvCoder<K, V> coder = (KvCoder<K, V>) entryCoder;
return coder.getKeyCoder();
}


/////////////////////////////////////////////////////////////////////////////

/**
Expand Down Expand Up @@ -198,7 +212,7 @@ public TupleTag<V> getTupleTag() {
*/
private final List<TaggedKeyedPCollection<K, ?>> keyedCollections;

private final Coder<K> keyCoder;
private Coder<K> keyCoder;

private final CoGbkResultSchema schema;

Expand All @@ -222,20 +236,6 @@ public TupleTag<V> getTupleTag() {
this.keyCoder = keyCoder;
}

private static <K, V> Coder<K> getKeyCoder(PCollection<KV<K, V>> pc) {
// Need to run coder inference on this PCollection before inspecting it.
pc.finishSpecifying();

// Assumes that the PCollection uses a KvCoder.
Coder<?> entryCoder = pc.getCoder();
if (!(entryCoder instanceof KvCoder<?, ?>)) {
throw new IllegalArgumentException("PCollection does not use a KvCoder");
}
@SuppressWarnings("unchecked")
KvCoder<K, V> coder = (KvCoder<K, V>) entryCoder;
return coder.getKeyCoder();
}

private static <K> List<TaggedKeyedPCollection<K, ?>> copyAddLast(
List<TaggedKeyedPCollection<K, ?>> keyedCollections,
TaggedKeyedPCollection<K, ?> taggedCollection) {
Expand Down
Expand Up @@ -224,15 +224,11 @@ public void recordAsOutput(AppliedPTransform<?, ?, ?> transform) {

@Override
public void finishSpecifying() {
for (PCollection<T> pc : pcollections) {
pc.finishSpecifying();
}
// All component PCollections will have already been finished.
}

@Override
public void finishSpecifyingOutput() {
for (PCollection<T> pc : pcollections) {
pc.finishSpecifyingOutput();
}
public void finishSpecifyingOutput(PInput input, PTransform<?, ?> transform) {
// All component PCollections will have already been finished.
}
}
Expand Up @@ -250,15 +250,11 @@ public void recordAsOutput(AppliedPTransform<?, ?, ?> transform) {

@Override
public void finishSpecifying() {
for (PCollection<?> pc : pcollectionMap.values()) {
pc.finishSpecifying();
}
// All component PCollections will already have been finished
}

@Override
public void finishSpecifyingOutput() {
for (PCollection<?> pc : pcollectionMap.values()) {
pc.finishSpecifyingOutput();
}
public void finishSpecifyingOutput(PInput input, PTransform<?, ?> transform) {
// All component PCollections will already have been finished
}
}
Expand Up @@ -19,6 +19,7 @@

import java.util.Collection;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.PTransform;

/**
* The interface for things that might be input to a
Expand Down Expand Up @@ -46,11 +47,12 @@ public interface PInput {
Collection<? extends PValue> expand();

/**
* After building, finalizes this {@code PInput} to make it ready for
* being used as an input to a {@link org.apache.beam.sdk.transforms.PTransform}.
* After building, finalizes this {@code PInput} to make it ready for being used as an input to a
* {@link org.apache.beam.sdk.transforms.PTransform}.
*
* <p>Automatically invoked whenever {@code apply()} is invoked on
* this {@code PInput}, so users do not normally call this explicitly.
* <p>Automatically invoked whenever {@code apply()} is invoked on this {@code PInput}, after
* {@link PValue#finishSpecifying(PInput, PTransform)} has been called on each component {@link
* PValue}, so users do not normally call this explicitly.
*/
void finishSpecifying();
}

0 comments on commit d6d7c1d

Please sign in to comment.