From 08b12dd8f35d46a62a51e1c9573dc9c4f01ad210 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 23 Aug 2016 09:37:07 -0700 Subject: [PATCH] Only Encode elements in EncodabilityEnforcement once Instead of checking that all input elements are encodable, ensure all elements produced by a PTransform can be encoded with the provided coder. This reduces the number of duplicate checks performed and enables EncodabilityEnforcement to be attached to Read PTransforms to ensure that provided coders can encode all elements output by a source. Enable EncodabilityEnforcement by Default This ensures that all elements, rather than only non-null elements, will have their encodability checked. --- .../beam/runners/direct/DirectOptions.java | 13 +- .../beam/runners/direct/DirectRunner.java | 16 ++- .../EncodabilityEnforcementFactory.java | 50 ++++--- .../beam/runners/direct/DirectRunnerTest.java | 66 ++++++++++ .../EncodabilityEnforcementFactoryTest.java | 122 ++++++++++++++---- 5 files changed, 215 insertions(+), 52 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java index 798fda4c4e9d..89e1bb805bcc 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java @@ -49,10 +49,17 @@ public interface DirectOptions extends PipelineOptions, ApplicationNameOptions { @Default.Boolean(true) @Description( - "Controls whether the runner should ensure that all of the elements of every " + "Controls whether the DirectRunner should ensure that all of the elements of every " + "PCollection are not mutated. PTransforms are not permitted to mutate input elements " + "at any point, or output elements after they are output.") - boolean isTestImmutability(); + boolean isEnforceImmutability(); - void setTestImmutability(boolean test); + void setEnforceImmutability(boolean test); + + @Default.Boolean(true) + @Description( + "Controls whether the DirectRunner should ensure that all of the elements of every " + + "PCollection are encodable. All elements in a PCollection must be encodable.") + boolean isEnforceEncodability(); + void setEnforceEncodability(boolean test); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index a3d20f69de5b..2ec4f08ddc75 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -39,6 +39,7 @@ import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PipelineRunner; @@ -284,21 +285,32 @@ public DirectPipelineResult run(Pipeline pipeline) { Collection parDoEnforcements = createParDoEnforcements(options); enforcements.put(ParDo.Bound.class, parDoEnforcements); enforcements.put(ParDo.BoundMulti.class, parDoEnforcements); + if (options.isEnforceEncodability()) { + enforcements.put( + Read.Unbounded.class, + ImmutableSet.of(EncodabilityEnforcementFactory.create())); + enforcements.put( + Read.Bounded.class, + ImmutableSet.of(EncodabilityEnforcementFactory.create())); + } return enforcements.build(); } private Collection createParDoEnforcements( DirectOptions options) { ImmutableList.Builder enforcements = ImmutableList.builder(); - if (options.isTestImmutability()) { + if (options.isEnforceImmutability()) { enforcements.add(ImmutabilityEnforcementFactory.create()); } + if (options.isEnforceEncodability()) { + enforcements.add(EncodabilityEnforcementFactory.create()); + } return enforcements.build(); } private BundleFactory createBundleFactory(DirectOptions pipelineOptions) { BundleFactory bundleFactory = ImmutableListBundleFactory.create(); - if (pipelineOptions.isTestImmutability()) { + if (pipelineOptions.isEnforceImmutability()) { bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory); } return bundleFactory; diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java index bed61ec5465c..0a5f03fbae5b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java @@ -32,38 +32,48 @@ * {@link PCollection PCollection's} {@link Coder}. */ class EncodabilityEnforcementFactory implements ModelEnforcementFactory { + // The factory proper is stateless + private static final EncodabilityEnforcementFactory INSTANCE = + new EncodabilityEnforcementFactory(); + public static EncodabilityEnforcementFactory create() { - return new EncodabilityEnforcementFactory(); + return INSTANCE; } @Override public ModelEnforcement forBundle( CommittedBundle input, AppliedPTransform consumer) { - return new EncodabilityEnforcement<>(input); + return new EncodabilityEnforcement<>(); } private static class EncodabilityEnforcement extends AbstractModelEnforcement { - private Coder coder; - - public EncodabilityEnforcement(CommittedBundle input) { - coder = input.getPCollection().getCoder(); + @Override + public void afterFinish( + CommittedBundle input, + TransformResult result, + Iterable> outputs) { + for (CommittedBundle bundle : outputs) { + ensureBundleEncodable(bundle); + } } - @Override - public void beforeElement(WindowedValue element) { - try { - T clone = CoderUtils.clone(coder, element.getValue()); - if (coder.consistentWithEquals()) { - checkArgument( - coder.structuralValue(element.getValue()).equals(coder.structuralValue(clone)), - "Coder %s of class %s does not maintain structural value equality" - + " on input element %s", - coder, - coder.getClass().getSimpleName(), - element.getValue()); + private void ensureBundleEncodable(CommittedBundle bundle) { + Coder coder = bundle.getPCollection().getCoder(); + for (WindowedValue element : bundle.getElements()) { + try { + T clone = CoderUtils.clone(coder, element.getValue()); + if (coder.consistentWithEquals()) { + checkArgument( + coder.structuralValue(element.getValue()).equals(coder.structuralValue(clone)), + "Coder %s of class %s does not maintain structural value equality" + + " on input element %s", + coder, + coder.getClass().getSimpleName(), + element.getValue()); + } + } catch (Exception e) { + throw UserCodeException.wrap(e); } - } catch (Exception e) { - throw UserCodeException.wrap(e); } } } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java index c7efac388b68..4768fb030fdb 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java @@ -19,11 +19,15 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -31,9 +35,12 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.io.CountingInput; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.PipelineRunner; @@ -378,4 +385,63 @@ public void testMutatingInputCoderDoFnError() throws Exception { thrown.expectMessage("must not be mutated"); pipeline.run(); } + + @Test + public void testUnencodableOutputElement() throws Exception { + Pipeline p = getPipeline(); + PCollection pcollection = + p.apply(Create.of((Void) null)).apply(ParDo.of(new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(null); + } + })).setCoder(VarLongCoder.of()); + pcollection + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void unreachable(ProcessContext c) { + fail("Pipeline should fail to encode a null Long in VarLongCoder"); + } + })); + + thrown.expectCause(isA(CoderException.class)); + thrown.expectMessage("cannot encode a null Long"); + p.run(); + } + + @Test + public void testUnencodableOutputFromBoundedRead() throws Exception { + Pipeline p = getPipeline(); + PCollection pCollection = + p.apply(CountingInput.upTo(10)).setCoder(new LongNoDecodeCoder()); + + thrown.expectCause(isA(CoderException.class)); + thrown.expectMessage("Cannot decode a long"); + p.run(); + } + + @Test + public void testUnencodableOutputFromUnboundedRead() { + Pipeline p = getPipeline(); + PCollection pCollection = + p.apply(CountingInput.unbounded()).setCoder(new LongNoDecodeCoder()); + + thrown.expectCause(isA(CoderException.class)); + thrown.expectMessage("Cannot decode a long"); + p.run(); + } + + private static class LongNoDecodeCoder extends AtomicCoder { + @Override + public void encode( + Long value, OutputStream outStream, Context context) throws IOException { + } + + @Override + public Long decode(InputStream inStream, Context context) throws IOException { + throw new CoderException("Cannot decode a long"); + } + } } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java index 4da4aad9d6ce..e62bf015b2e8 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java @@ -24,6 +24,8 @@ import java.io.OutputStream; import java.util.Collections; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; +import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -32,10 +34,15 @@ import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Instant; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -51,16 +58,43 @@ public class EncodabilityEnforcementFactoryTest { private EncodabilityEnforcementFactory factory = EncodabilityEnforcementFactory.create(); private BundleFactory bundleFactory = ImmutableListBundleFactory.create(); + private PCollection inputPCollection; + private CommittedBundle inputBundle; + private PCollection outputPCollection; + + @Before + public void setup() { + Pipeline p = TestPipeline.create(); + inputPCollection = p.apply(Create.of(new Record()).withCoder(new RecordNoDecodeCoder())); + outputPCollection = inputPCollection.apply(ParDo.of(new IdentityDoFn())); + + inputBundle = + bundleFactory + .createRootBundle() + .add(WindowedValue.valueInGlobalWindow(new Record())) + .commit(Instant.now()); + } + @Test public void encodeFailsThrows() { WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); ModelEnforcement enforcement = createEnforcement(new RecordNoEncodeCoder(), record); + UncommittedBundle output = + bundleFactory.createBundle(outputPCollection).add(record); + + enforcement.beforeElement(record); + enforcement.afterElement(record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); thrown.expectMessage("Encode not allowed"); - enforcement.beforeElement(record); + enforcement.afterFinish( + inputBundle, + StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) + .addOutput(output) + .build(), + Collections.>singleton(output.commit(Instant.now()))); } @Test @@ -69,10 +103,20 @@ public void decodeFailsThrows() { ModelEnforcement enforcement = createEnforcement(new RecordNoDecodeCoder(), record); + UncommittedBundle output = + bundleFactory.createBundle(outputPCollection).add(record); + + enforcement.beforeElement(record); + enforcement.afterElement(record); thrown.expect(UserCodeException.class); thrown.expectCause(isA(CoderException.class)); thrown.expectMessage("Decode not allowed"); - enforcement.beforeElement(record); + enforcement.afterFinish( + inputBundle, + StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) + .addOutput(output) + .build(), + Collections.>singleton(output.commit(Instant.now()))); } @Test @@ -89,46 +133,57 @@ public String toString() { ModelEnforcement enforcement = createEnforcement(new RecordStructuralValueCoder(), record); + UncommittedBundle output = + bundleFactory.createBundle(outputPCollection).add(record); + + enforcement.beforeElement(record); + enforcement.afterElement(record); + thrown.expect(UserCodeException.class); thrown.expectCause(isA(IllegalArgumentException.class)); thrown.expectMessage("does not maintain structural value equality"); thrown.expectMessage(RecordStructuralValueCoder.class.getSimpleName()); thrown.expectMessage("OriginalRecord"); - enforcement.beforeElement(record); + enforcement.afterFinish( + inputBundle, + StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) + .addOutput(output) + .build(), + Collections.>singleton(output.commit(Instant.now()))); } @Test public void notConsistentWithEqualsStructuralValueNotEqualSucceeds() { - TestPipeline p = TestPipeline.create(); - PCollection unencodable = - p.apply( - Create.of(new Record()) - .withCoder(new RecordNotConsistentWithEqualsStructuralValueCoder())); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - + outputPCollection.setCoder(new RecordNotConsistentWithEqualsStructuralValueCoder()); WindowedValue record = WindowedValue.valueInGlobalWindow(new Record()); - CommittedBundle input = - bundleFactory.createBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = + factory.forBundle(inputBundle, outputPCollection.getProducingTransformInternal()); + + UncommittedBundle output = + bundleFactory.createBundle(outputPCollection).add(record); enforcement.beforeElement(record); enforcement.afterElement(record); enforcement.afterFinish( - input, - StepTransformResult.withoutHold(consumer).build(), - Collections.>emptyList()); + inputBundle, + StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) + .addOutput(output) + .build(), + Collections.>singleton(output.commit(Instant.now()))); } - private ModelEnforcement createEnforcement(Coder coder, WindowedValue record) { + private ModelEnforcement createEnforcement( + Coder coder, WindowedValue record) { TestPipeline p = TestPipeline.create(); - PCollection unencodable = p.apply(Create.of().withCoder(coder)); - AppliedPTransform consumer = - unencodable.apply(Count.globally()).getProducingTransformInternal(); - CommittedBundle input = + PCollection unencodable = p.apply(Create.of().withCoder(coder)); + outputPCollection = + unencodable.apply( + MapElements.via(new SimpleIdentity())); + AppliedPTransform consumer = outputPCollection.getProducingTransformInternal(); + CommittedBundle input = bundleFactory.createBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement enforcement = factory.forBundle(input, consumer); + ModelEnforcement enforcement = factory.forBundle(input, consumer); return enforcement; } @@ -161,14 +216,14 @@ public void encode( Record value, OutputStream outStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { + throws IOException { throw new CoderException("Encode not allowed"); } @Override public Record decode( InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { + throws IOException { return null; } } @@ -179,12 +234,12 @@ public void encode( Record value, OutputStream outStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException {} + throws IOException {} @Override public Record decode( InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { + throws IOException { throw new CoderException("Decode not allowed"); } } @@ -252,4 +307,17 @@ public Object structuralValue(Record value) { } } + private static class IdentityDoFn extends DoFn { + @ProcessElement + public void proc(ProcessContext ctxt) { + ctxt.output(ctxt.element()); + } + } + + private static class SimpleIdentity extends SimpleFunction { + @Override + public Record apply(Record input) { + return input; + } + } }