From ca7b9c288151d318898ab000b91d26fcf62046ca Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 06:29:09 -0700 Subject: [PATCH 1/2] Add Runner API oriented PTransformMatchers for DirectRunner overrides --- .../core/construction/PTransformMatchers.java | 94 ++++++++++++++++++- .../construction/PTransformTranslation.java | 7 +- .../construction/PTransformMatchersTest.java | 32 +++++++ 3 files changed, 128 insertions(+), 5 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java index bfe24a02ab63..c339891d51ed 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.core.construction; import com.google.common.base.MoreObjects; +import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.apache.beam.sdk.annotations.Experimental; @@ -49,6 +50,34 @@ public class PTransformMatchers { private PTransformMatchers() {} + /** + * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the URN of the + * {@link PTransform} is equal to the URN provided to this matcher. 
+ */ + public static PTransformMatcher urnEqualTo(String urn) { + return new EqualUrnPTransformMatcher(urn); + } + + private static class EqualUrnPTransformMatcher implements PTransformMatcher { + private final String urn; + + private EqualUrnPTransformMatcher(String urn) { + this.urn = urn; + } + + @Override + public boolean matches(AppliedPTransform application) { + return urn.equals(PTransformTranslation.urnForTransformOrNull(application.getTransform())); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("urn", urn) + .toString(); + } + } + /** * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the * {@link PTransform} is equal to the {@link Class} provided ot this matcher. @@ -150,6 +179,68 @@ public String toString() { }; } + /** + * A {@link PTransformMatcher} that matches a {@link ParDo} by URN if it has a splittable {@link + * DoFn}. + */ + public static PTransformMatcher splittableParDo() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform application) { + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { + + try { + return ParDoTranslation.isSplittable(application); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Transform with URN %s could not be translated", + PTransformTranslation.PAR_DO_TRANSFORM_URN), + e); + } + } + return false; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("SplittableParDoMultiMatcher").toString(); + } + }; + } + + /** + * A {@link PTransformMatcher} that matches a {@link ParDo} transform by URN + * and whether it contains state or timers as specified by {@link ParDoTranslation}. 
+ */ + public static PTransformMatcher stateOrTimerParDo() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform application) { + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { + + try { + return ParDoTranslation.usesStateOrTimers(application); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Transform with URN %s could not be translated", + PTransformTranslation.PAR_DO_TRANSFORM_URN), + e); + } + } + return false; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("StateOrTimerParDoMatcher").toString(); + } + }; + } + /** * A {@link PTransformMatcher} that matches a {@link ParDo.MultiOutput} containing a {@link DoFn} * that uses state or timers, as specified by {@link DoFnSignature#usesState()} and @@ -268,7 +359,8 @@ public static PTransformMatcher writeWithRunnerDeterminedSharding() { return new PTransformMatcher() { @Override public boolean matches(AppliedPTransform application) { - if (application.getTransform() instanceof WriteFiles) { + if (PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { WriteFiles write = (WriteFiles) application.getTransform(); return write.getSharding() == null && write.getNumShards() == null; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 32ecf430c271..bae7b0574b22 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -179,13 +179,12 @@ public static String 
urnForTransformOrNull(PTransform transform) { * Returns the URN for the transform if it is known, otherwise throws. */ public static String urnForTransform(PTransform transform) { - TransformPayloadTranslator translator = KNOWN_PAYLOAD_TRANSLATORS.get(transform.getClass()); - if (translator == null) { + String urn = urnForTransformOrNull(transform); + if (urn == null) { throw new IllegalStateException( String.format("No translator known for %s", transform.getClass().getName())); } - - return translator.getUrn(transform); + return urn; } /** diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java index 249759880803..6459849f24fa 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java @@ -27,6 +27,8 @@ import com.google.common.collect.ImmutableMap; import java.io.Serializable; import java.util.Collections; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.DefaultFilenamePolicy; @@ -95,9 +97,14 @@ public class PTransformMatchersTest implements Serializable { PCollection> input = PCollection.createPrimitiveOutputInternal( p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED); + input.setName("dummy input"); + input.setCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); + PCollection output = PCollection.createPrimitiveOutputInternal( p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED); + output.setName("dummy output"); + output.setCoder(VarIntCoder.of()); return AppliedPTransform.of("pardo", input.expand(), 
output.expand(), pardo, p); } @@ -271,6 +278,18 @@ public void parDoMultiSplittable() { assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); } + @Test + public void parDoSplittable() { + AppliedPTransform parDoApplication = + getAppliedTransform( + ParDo.of(splittableDoFn).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.splittableParDo().matches(parDoApplication), is(true)); + + assertThat(PTransformMatchers.stateOrTimerParDoMulti().matches(parDoApplication), is(false)); + assertThat(PTransformMatchers.splittableParDoSingle().matches(parDoApplication), is(false)); + assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); + } + @Test public void parDoMultiWithState() { AppliedPTransform parDoApplication = @@ -283,6 +302,19 @@ public void parDoMultiWithState() { assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); } + @Test + public void parDoWithState() { + AppliedPTransform statefulApplication = + getAppliedTransform( + ParDo.of(doFnWithState).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.stateOrTimerParDo().matches(statefulApplication), is(true)); + + AppliedPTransform splittableApplication = + getAppliedTransform( + ParDo.of(splittableDoFn).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.stateOrTimerParDo().matches(splittableApplication), is(false)); + } + @Test public void parDoMultiWithTimers() { AppliedPTransform parDoApplication = From d8d9087877c01f1786271726a541fb3eeda7f939 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 06:31:16 -0700 Subject: [PATCH 2/2] DirectRunner override matchers using Runner API --- .../beam/runners/direct/DirectRunner.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index dbd1ec47ed54..136ccf3bd86a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -30,6 +30,7 @@ import java.util.Set; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformMatchers; +import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; @@ -42,12 +43,9 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; -import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; -import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; @@ -230,33 +228,33 @@ List defaultTransformOverrides() { new WriteWithShardingFactory())) /* Uses a view internally. 
*/ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(CreatePCollectionView.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), new ViewOverrideFactory())) /* Uses pardos and GBKs */ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(TestStream.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), new DirectTestStreamFactory(this))) /* primitive */ // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra // primitives .add( PTransformOverride.of( - PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory())) + PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())) // state and timer pardos are implemented in terms of simple ParDos and extra primitives .add( PTransformOverride.of( - PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory())) + PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class), + PTransformMatchers.urnEqualTo( + SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN), new SplittableParDoViaKeyedWorkItems.OverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo( - SplittableParDoViaKeyedWorkItems.GBKIntoKeyedWorkItems.class), + PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN), new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(GroupByKey.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), new DirectGroupByKeyOverrideFactory())) /* returns two chained primitives. */ .build(); }