Skip to content

Commit

Permalink
Port direct runner StatefulParDo to KeyedWorkItem
Browse files Browse the repository at this point in the history
  • Loading branch information
kennknowles committed Dec 14, 2016
1 parent 2780090 commit 5d912cf
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 58 deletions.
Expand Up @@ -31,6 +31,7 @@
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PValue;

/**
Expand Down Expand Up @@ -105,7 +106,15 @@ public Set<PValue> getKeyedPValues() {
}

/**
 * Reports whether the direct runner treats {@code transform} as key-preserving.
 *
 * <p>Only a {@link ParDo.BoundMulti} whose {@code DoFn} is a {@link
 * ParDoMultiOverrideFactory.ToKeyedWorkItem} qualifies; every other transform is reported as not
 * key-preserving.
 *
 * @param transform the transform to inspect
 * @return {@code true} if the transform is the distinguished key-preserving {@link ParDo}
 */
private static boolean isKeyPreserving(PTransform<?, ?> transform) {
  // This is a hacky check for what is considered key-preserving to the direct runner.
  // The most obvious alternative would be a package-private marker interface, but
  // better to make this obviously hacky so it is less likely to proliferate. Meanwhile
  // we intend to allow explicit expression of key-preserving DoFn in the model.
  if (!(transform instanceof ParDo.BoundMulti)) {
    return false;
  }
  ParDo.BoundMulti<?, ?> parDo = (ParDo.BoundMulti<?, ?>) transform;
  return parDo.getNewFn() instanceof ParDoMultiOverrideFactory.ToKeyedWorkItem;
}
}
Expand Up @@ -17,15 +17,21 @@
*/
package org.apache.beam.runners.direct;

import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.runners.core.KeyedWorkItemCoder;
import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.core.SplittableParDo;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
Expand Down Expand Up @@ -83,16 +89,36 @@ public GbkThenStatefulParDo(ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingP
/**
 * Expands the stateful {@link ParDo} into: reify windowed values, group by key, adapt to {@link
 * KeyedWorkItem}, then the runner-specific {@link StatefulParDo}.
 *
 * @param input the keyed input; its coder is cast to {@link KvCoder}
 * @return the outputs produced by the {@link StatefulParDo} step
 */
@Override
public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {

  // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder
  // is not registered by default, so we explicitly set the relevant coders.
  KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
  Coder<K> keyCoder = kvCoder.getKeyCoder();
  Coder<? extends BoundedWindow> windowCoder =
      input.getWindowingStrategy().getWindowFn().windowCoder();

  PCollectionTuple outputs =
      input
          // Stash the original timestamps, etc, for when it is fed to the user's DoFn
          .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<K, InputT>()))
          .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder)))

          // A full GBK to group by key _and_ window
          .apply("Group by key", GroupByKey.<K, WindowedValue<KV<K, InputT>>>create())

          // Adapt to KeyedWorkItem; that is how this runner delivers timers
          .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<K, InputT>()))
          .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder))

          // Explode the resulting iterable into elements that are exactly the ones from
          // the input
          .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));

  return outputs;
}
}

static class StatefulParDo<K, InputT, OutputT>
extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple> {
extends PTransform<PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> {
private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
private final transient PCollection<KV<K, InputT>> originalInput;

Expand All @@ -109,21 +135,58 @@ public ParDo.BoundMulti<KV<K, InputT>, OutputT> getUnderlyingParDo() {

@Override
public <T> Coder<T> getDefaultOutputCoder(
PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> output)
PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, TypedPValue<T> output)
throws CannotProvideCoderException {
return underlyingParDo.getDefaultOutputCoder(originalInput, output);
}

public PCollectionTuple expand(PCollection<? extends KV<K, Iterable<InputT>>> input) {
@Override
public PCollectionTuple expand(PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input) {

PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(underlyingParDo.getMainOutputTag())
.and(underlyingParDo.getSideOutputTags().getAll()),
input.getWindowingStrategy(),
input.isBounded());
PCollectionTuple outputs =
PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(underlyingParDo.getMainOutputTag())
.and(underlyingParDo.getSideOutputTags().getAll()),
input.getWindowingStrategy(),
input.isBounded());

return outputs;
}
}

/**
 * A {@link DoFn} that reifies each element's windowing metadata into its value.
 *
 * <p>For every input {@link KV}, this emits a {@link KV} with the same key whose value is a
 * {@link WindowedValue} holding the original {@link KV} together with its timestamp, window, and
 * pane. This stashes the original timestamps, etc, so they are still available when the element
 * is later fed to the user's {@link DoFn}.
 */
static class ReifyWindowedValueFn<K, V> extends DoFn<KV<K, V>, KV<K, WindowedValue<KV<K, V>>>> {
  @ProcessElement
  public void processElement(final ProcessContext c, final BoundedWindow window) {
    KV<K, V> original = c.element();
    // Capture the element's timestamp, window, and pane alongside the element itself.
    WindowedValue<KV<K, V>> reified =
        WindowedValue.of(original, c.timestamp(), window, c.pane());
    c.output(KV.of(original.getKey(), reified));
  }
}

/**
 * A runner-specific primitive that is just a key-preserving {@link ParDo}, but we do not have the
 * machinery to detect or enforce that yet.
 *
 * <p>This wraps the {@link GroupByKey} output in a {@link KeyedWorkItem} to be able to deliver
 * timers. The grouped values are handed to the work item as-is: this {@link DoFn} does not itself
 * iterate over them, and the output work item carries the same key as the input element.
 */
static class ToKeyedWorkItem<K, V>
    extends DoFn<KV<K, Iterable<WindowedValue<KV<K, V>>>>, KeyedWorkItem<K, KV<K, V>>> {

  // NOTE(review): `window` is not read in the body; presumably it is declared so the DoFn is
  // invoked per window -- confirm before removing it.
  @ProcessElement
  public void processElement(final ProcessContext c, final BoundedWindow window) {
    KV<K, Iterable<WindowedValue<KV<K, V>>>> grouped = c.element();
    c.output(KeyedWorkItems.elementsWorkItem(grouped.getKey(), grouped.getValue()));
  }
}
}
Expand Up @@ -23,6 +23,8 @@
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Lists;
import java.util.Collections;
import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.runners.core.KeyedWorkItems;
import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
Expand Down Expand Up @@ -77,12 +79,12 @@ public void cleanup() throws Exception {
}

@SuppressWarnings({"unchecked", "rawtypes"})
private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator(
private TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> createEvaluator(
AppliedPTransform<
PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
application,
CommittedBundle<KV<K, Iterable<InputT>>> inputBundle)
CommittedBundle<KeyedWorkItem<K, KV<K, InputT>>> inputBundle)
throws Exception {

final DoFn<KV<K, InputT>, OutputT> doFn =
Expand Down Expand Up @@ -185,7 +187,7 @@ public void run() {
@AutoValue
abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> {
abstract AppliedPTransform<
PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
getTransform();

Expand All @@ -195,7 +197,7 @@ abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> {

static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create(
AppliedPTransform<
PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple,
StatefulParDo<K, InputT, OutputT>>
transform,
StructuralKey<K> key,
Expand All @@ -206,7 +208,7 @@ static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, Outpu
}

private static class StatefulParDoEvaluator<K, InputT>
implements TransformEvaluator<KV<K, Iterable<InputT>>> {
implements TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> {

private final TransformEvaluator<KV<K, InputT>> delegateEvaluator;

Expand All @@ -215,20 +217,20 @@ public StatefulParDoEvaluator(TransformEvaluator<KV<K, InputT>> delegateEvaluato
}

@Override
public void processElement(WindowedValue<KV<K, Iterable<InputT>>> gbkResult) throws Exception {
public void processElement(WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> gbkResult)
throws Exception {

for (InputT value : gbkResult.getValue().getValue()) {
delegateEvaluator.processElement(
gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value)));
for (WindowedValue<KV<K, InputT>> windowedValue : gbkResult.getValue().elementsIterable()) {
delegateEvaluator.processElement(windowedValue);
}
}

@Override
public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception {
public TransformResult<KeyedWorkItem<K, KV<K, InputT>>> finishBundle() throws Exception {
TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle();

StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult =
StepTransformResult.<KV<K, Iterable<InputT>>>withHold(
StepTransformResult.Builder<KeyedWorkItem<K, KV<K, InputT>>> regroupedResult =
StepTransformResult.<KeyedWorkItem<K, KV<K, InputT>>>withHold(
delegateResult.getTransform(), delegateResult.getWatermarkHold())
.withTimerUpdate(delegateResult.getTimerUpdate())
.withAggregatorChanges(delegateResult.getAggregatorChanges())
Expand All @@ -240,12 +242,10 @@ public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception
// outputs, but just make a bunch of singletons
for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) {
WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed;
WindowedValue<KV<K, Iterable<InputT>>> pushedBack =
WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> pushedBack =
windowedKv.withValue(
KV.of(
windowedKv.getValue().getKey(),
(Iterable<InputT>)
Collections.singletonList(windowedKv.getValue().getValue())));
KeyedWorkItems.elementsWorkItem(
windowedKv.getValue().getKey(), Collections.singleton(windowedKv)));

regroupedResult.addUnprocessedElements(pushedBack);
}
Expand Down
Expand Up @@ -22,9 +22,11 @@
import static org.junit.Assert.assertThat;

import java.util.Collections;
import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.testing.TestPipeline;
Expand All @@ -33,18 +35,20 @@
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link KeyedPValueTrackingVisitor}. */
@RunWith(JUnit4.class)
public class KeyedPValueTrackingVisitorTest {
@Rule public ExpectedException thrown = ExpectedException.none();
Expand All @@ -61,8 +65,7 @@ public void setup() {
@Test
public void groupByKeyProducesKeyedOutput() {
PCollection<KV<String, Iterable<Integer>>> keyed =
p.apply(Create.of(KV.of("foo", 3)))
.apply(GroupByKey.<String, Integer>create());
p.apply(Create.of(KV.of("foo", 3))).apply(GroupByKey.<String, Integer>create());

p.traverseTopologically(visitor);
assertThat(visitor.getKeyedPValues(), hasItem(keyed));
Expand Down Expand Up @@ -90,17 +93,67 @@ public void keyedInputWithoutKeyPreserving() {
assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed)));
}

/**
 * Applying the key-preserving {@code ToKeyedWorkItem} ParDo directly to a {@link Create} output
 * (no {@link GroupByKey} upstream) must not make the result a keyed {@code PValue}.
 */
@Test
public void unkeyedInputWithKeyPreserving() {
  // The single element has key "hello" and an empty iterable of windowed KVs as its value.
  Iterable<WindowedValue<KV<String, Integer>>> noValues =
      Collections.<WindowedValue<KV<String, Integer>>>emptyList();

  KvCoder<String, Integer> elementCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());
  IterableCoder<WindowedValue<KV<String, Integer>>> valuesCoder =
      IterableCoder.of(WindowedValue.getValueOnlyCoder(elementCoder));

  PCollection<KV<String, Iterable<WindowedValue<KV<String, Integer>>>>> input =
      p.apply(
          Create.of(KV.of("hello", noValues))
              .withCoder(KvCoder.of(StringUtf8Coder.of(), valuesCoder)));

  PCollection<KeyedWorkItem<String, KV<String, Integer>>> unkeyed =
      input.apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));

  p.traverseTopologically(visitor);
  assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed)));
}

/**
 * Applying the key-preserving {@code ToKeyedWorkItem} ParDo to the output of a {@link
 * GroupByKey} must yield a keyed {@code PValue}.
 */
@Test
public void keyedInputWithKeyPreserving() {
  // One windowed element, keyed by "hello", in the interval window [0, 9).
  WindowedValue<KV<String, Integer>> windowedElement =
      WindowedValue.of(
          KV.of("hello", 3),
          new Instant(0),
          new IntervalWindow(new Instant(0), new Instant(9)),
          PaneInfo.NO_FIRING);

  KvCoder<String, Integer> elementCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());

  PCollection<KV<String, WindowedValue<KV<String, Integer>>>> input =
      p.apply(
          Create.of(KV.of("hello", windowedElement))
              .withCoder(
                  KvCoder.of(
                      StringUtf8Coder.of(), WindowedValue.getValueOnlyCoder(elementCoder))));

  PCollection<KeyedWorkItem<String, KV<String, Integer>>> keyed =
      input
          .apply(GroupByKey.<String, WindowedValue<KV<String, Integer>>>create())
          .apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()));

  p.traverseTopologically(visitor);
  assertThat(visitor.getKeyedPValues(), hasItem(keyed));
}

@Test
public void traverseMultipleTimesThrows() {
p.apply(
Create.<KV<Integer, Void>>of(
KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
Create.of(KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null))
.withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of())))
.apply(GroupByKey.<Integer, Void>create())
.apply(Keys.<Integer>create());

p.traverseTopologically(visitor);

thrown.expect(IllegalStateException.class);
thrown.expectMessage("already been finalized");
thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName());
Expand Down

0 comments on commit 5d912cf

Please sign in to comment.