From ca1b3a87fbe7218c0af912ba0f0deae8b903b1ac Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 28 Apr 2016 15:51:40 -0700 Subject: [PATCH 01/21] Add accessors for sub-coders of KeyedWorkItemCoder --- .../java/org/apache/beam/sdk/util/KeyedWorkItemCoder.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/KeyedWorkItemCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/KeyedWorkItemCoder.java index 763f68b302e3..ec5d82138269 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/KeyedWorkItemCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/KeyedWorkItemCoder.java @@ -79,6 +79,14 @@ private KeyedWorkItemCoder( this.elemsCoder = IterableCoder.of(FullWindowedValueCoder.of(elemCoder, windowCoder)); } + public Coder getKeyCoder() { + return keyCoder; + } + + public Coder getElementCoder() { + return elemCoder; + } + @Override public void encode(KeyedWorkItem value, OutputStream outStream, Coder.Context context) throws CoderException, IOException { From aad284a513e829552e9ae7fa10ea89cfd89bdb5f Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 28 Apr 2016 16:12:21 -0700 Subject: [PATCH 02/21] Make in-process GroupByKey respect future Beam model This introduces or clarifies the following transforms: - InProcessGroupByKey, which expands like GroupByKeyViaGroupByKeyOnly but with different intermediate PCollection types. - InProcessGroupByKeyOnly, which outputs KeyedWorkItem. This existed already under a different name. - InProcessGroupAlsoByWindow, which is evaluated directly and accepts input elements of type KeyedWorkItem. --- .../beam/runners/direct/BundleFactory.java | 2 +- .../direct/InProcessEvaluationContext.java | 2 +- ...cessGroupAlsoByWindowEvaluatorFactory.java | 127 ++++++++++++ .../runners/direct/InProcessGroupByKey.java | 132 +++++++++++++ ...rocessGroupByKeyOnlyEvaluatorFactory.java} | 153 +++------------ .../InProcessGroupByKeyOverrideFactory.java | 41 ++++ .../direct/InProcessPipelineRunner.java | 3 +- .../direct/TransformEvaluatorRegistry.java | 8 +- .../GroupByKeyEvaluatorFactoryTest.java | 4 +- ...essGroupByKeyOnlyEvaluatorFactoryTest.java | 183 ++++++++++++++++++ 10 files changed, 524 insertions(+), 131 deletions(-) create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupAlsoByWindowEvaluatorFactory.java create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKey.java rename runners/direct-java/src/main/java/org/apache/beam/runners/direct/{GroupByKeyEvaluatorFactory.java => InProcessGroupByKeyOnlyEvaluatorFactory.java} (52%) create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOverrideFactory.java create mode 100644 runners/direct-java/src/test/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactoryTest.java diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BundleFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BundleFactory.java index 34529e7803c1..fea48416252d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BundleFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BundleFactory.java @@ -17,7 +17,7 @@ */ package org.apache.beam.runners.direct; -import org.apache.beam.runners.direct.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; +import 
org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; import org.apache.beam.sdk.transforms.PTransform; diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessEvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessEvaluationContext.java index 9eeafbb5afe7..f348d9346a60 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessEvaluationContext.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessEvaluationContext.java @@ -19,9 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; -import org.apache.beam.runners.direct.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.InMemoryWatermarkManager.FiredTimers; import org.apache.beam.runners.direct.InMemoryWatermarkManager.TransformWatermarks; +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; import org.apache.beam.runners.direct.InProcessPipelineRunner.PCollectionViewWriter; import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupAlsoByWindowEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupAlsoByWindowEvaluatorFactory.java new file mode 100644 index 000000000000..5ded8b68f8d3 --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupAlsoByWindowEvaluatorFactory.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.direct; + +import org.apache.beam.runners.core.GroupAlsoByWindowViaWindowSetDoFn; +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupAlsoByWindow; +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; +import org.apache.beam.sdk.util.KeyedWorkItem; +import org.apache.beam.sdk.util.SystemReduceFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; + +import com.google.common.collect.ImmutableMap; + +import java.util.Collections; + +/** + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the + * {@link GroupByKeyOnly} {@link PTransform}. + */ +class InProcessGroupAlsoByWindowEvaluatorFactory implements TransformEvaluatorFactory { + @Override + public TransformEvaluator forApplication( + AppliedPTransform application, + CommittedBundle inputBundle, + InProcessEvaluationContext evaluationContext) { + @SuppressWarnings({"cast", "unchecked", "rawtypes"}) + TransformEvaluator evaluator = + createEvaluator( + (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); + return evaluator; + } + + private TransformEvaluator> createEvaluator( + AppliedPTransform< + PCollection>, PCollection>>, + InProcessGroupAlsoByWindow> + application, + CommittedBundle> inputBundle, + InProcessEvaluationContext evaluationContext) { + return new InProcessGroupAlsoByWindowEvaluator( + evaluationContext, inputBundle, application); + } + + /** + * A transform evaluator for the pseudo-primitive {@link GroupAlsoByWindow}. Windowing is ignored; + * all input should be in the global window since all output will be as well. 
+ * + * @see GroupByKeyViaGroupByKeyOnly + */ + private static class InProcessGroupAlsoByWindowEvaluator + implements TransformEvaluator> { + + private final TransformEvaluator> gabwParDoEvaluator; + + public InProcessGroupAlsoByWindowEvaluator( + final InProcessEvaluationContext evaluationContext, + CommittedBundle> inputBundle, + final AppliedPTransform< + PCollection>, PCollection>>, + InProcessGroupAlsoByWindow> + application) { + + Coder valueCoder = + application.getTransform().getValueCoder(inputBundle.getPCollection().getCoder()); + + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) application.getTransform().getWindowingStrategy(); + + DoFn, KV>> gabwDoFn = + GroupAlsoByWindowViaWindowSetDoFn.create( + windowingStrategy, + SystemReduceFn.buffering(valueCoder)); + + TupleTag>> mainOutputTag = new TupleTag>>() {}; + + // Not technically legit, as the application is not a ParDo + this.gabwParDoEvaluator = + ParDoInProcessEvaluator.create( + evaluationContext, + inputBundle, + application, + gabwDoFn, + Collections.>emptyList(), + mainOutputTag, + Collections.>emptyList(), + ImmutableMap., PCollection>of(mainOutputTag, application.getOutput())); + } + + @Override + public void processElement(WindowedValue> element) throws Exception { + gabwParDoEvaluator.processElement(element); + } + + @Override + public InProcessTransformResult finishBundle() throws Exception { + return gabwParDoEvaluator.finishBundle(); + } + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKey.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKey.java new file mode 100644 index 000000000000..026b4d5636f3 --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKey.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.direct; + +import static com.google.common.base.Preconditions.checkArgument; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.ReifyTimestampsAndWindows; +import org.apache.beam.sdk.util.KeyedWorkItem; +import org.apache.beam.sdk.util.KeyedWorkItemCoder; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + +class InProcessGroupByKey + extends ForwardingPTransform>, PCollection>>> { + private final GroupByKey original; + + InProcessGroupByKey(GroupByKey from) { + this.original = from; + } + + @Override + public PTransform>, PCollection>>> delegate() { + return original; + } + + @Override + public PCollection>> apply(PCollection> input) { + @SuppressWarnings("unchecked") + KvCoder inputCoder = (KvCoder) input.getCoder(); + + // This operation groups by the combination of key and window, + // merging windows as needed, using the windows assigned to the + // key/value input elements and the window merge operation of the + // window function associated with the input PCollection. + WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + + // By default, implement GroupByKey via a series of lower-level operations. + return input + // Make each input element's timestamp and assigned windows + // explicit, in the value part. + .apply(new ReifyTimestampsAndWindows()) + .apply(new InProcessGroupByKeyOnly()) + .setCoder( + KeyedWorkItemCoder.of( + inputCoder.getKeyCoder(), + inputCoder.getValueCoder(), + input.getWindowingStrategy().getWindowFn().windowCoder())) + + // Group each key's values by window, merging windows as needed. + .apply("GroupAlsoByWindow", new InProcessGroupAlsoByWindow(windowingStrategy)) + + // And update the windowing strategy as appropriate. 
+ .setWindowingStrategyInternal(original.updateWindowingStrategy(windowingStrategy)) + .setCoder( + KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getValueCoder()))); + } + + static final class InProcessGroupByKeyOnly + extends PTransform>>, PCollection>> { + @Override + public PCollection> apply(PCollection>> input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); + } + + InProcessGroupByKeyOnly() {} + } + + static final class InProcessGroupAlsoByWindow + extends PTransform>, PCollection>>> { + + private final WindowingStrategy windowingStrategy; + + public InProcessGroupAlsoByWindow(WindowingStrategy windowingStrategy) { + this.windowingStrategy = windowingStrategy; + } + + public WindowingStrategy getWindowingStrategy() { + return windowingStrategy; + } + + private KeyedWorkItemCoder getKeyedWorkItemCoder(Coder> inputCoder) { + // Coder> --> KvCoder<...> + checkArgument( + inputCoder instanceof KeyedWorkItemCoder, + "%s requires a %s<...> but got %s", + getClass().getSimpleName(), + KvCoder.class.getSimpleName(), + inputCoder); + @SuppressWarnings("unchecked") + KeyedWorkItemCoder kvCoder = (KeyedWorkItemCoder) inputCoder; + return kvCoder; + } + + public Coder getKeyCoder(Coder> inputCoder) { + return getKeyedWorkItemCoder(inputCoder).getKeyCoder(); + } + + public Coder getValueCoder(Coder> inputCoder) { + return getKeyedWorkItemCoder(inputCoder).getElementCoder(); + } + + @Override + public PCollection>> apply(PCollection> input) { + return PCollection.>>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); + } + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactory.java similarity index 52% rename from runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactory.java rename to runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactory.java index 9a08996be215..79db5b696659 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactory.java @@ -19,33 +19,24 @@ import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; -import org.apache.beam.runners.core.GroupAlsoByWindowViaWindowSetDoFn; +import static com.google.common.base.Preconditions.checkState; + +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; import org.apache.beam.runners.direct.StepTransformResult.Builder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.ReifyTimestampsAndWindows; 
+import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; import org.apache.beam.sdk.util.KeyedWorkItem; -import org.apache.beam.sdk.util.KeyedWorkItemCoder; import org.apache.beam.sdk.util.KeyedWorkItems; -import org.apache.beam.sdk.util.SystemReduceFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; - -import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.Arrays; @@ -54,17 +45,18 @@ import java.util.Map; /** - * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the {@link GroupByKey} - * {@link PTransform}. + * The {@link InProcessPipelineRunner} {@link TransformEvaluatorFactory} for the + * {@link GroupByKeyOnly} {@link PTransform}. */ -class GroupByKeyEvaluatorFactory implements TransformEvaluatorFactory { +class InProcessGroupByKeyOnlyEvaluatorFactory implements TransformEvaluatorFactory { @Override public TransformEvaluator forApplication( AppliedPTransform application, CommittedBundle inputBundle, InProcessEvaluationContext evaluationContext) { @SuppressWarnings({"cast", "unchecked", "rawtypes"}) - TransformEvaluator evaluator = createEvaluator( + TransformEvaluator evaluator = + createEvaluator( (AppliedPTransform) application, (CommittedBundle) inputBundle, evaluationContext); return evaluator; } @@ -74,16 +66,22 @@ private TransformEvaluator>> createEvaluator( PCollection>>, PCollection>, InProcessGroupByKeyOnly> application, - final CommittedBundle> inputBundle, + final CommittedBundle>> inputBundle, final InProcessEvaluationContext evaluationContext) { - return new GroupByKeyEvaluator(evaluationContext, inputBundle, application); + return new InProcessGroupByKeyOnlyEvaluator(evaluationContext, inputBundle, application); } - private static class GroupByKeyEvaluator + /** + * A transform evaluator for the pseudo-primitive {@link GroupByKeyOnly}. Windowing is ignored; + * all input should be in the global window since all output will be as well. 
+ * + * @see GroupByKeyViaGroupByKeyOnly + */ + private static class InProcessGroupByKeyOnlyEvaluator implements TransformEvaluator>> { private final InProcessEvaluationContext evaluationContext; - private final CommittedBundle> inputBundle; + private final CommittedBundle>> inputBundle; private final AppliedPTransform< PCollection>>, PCollection>, InProcessGroupByKeyOnly> @@ -91,9 +89,9 @@ private static class GroupByKeyEvaluator private final Coder keyCoder; private Map, List>> groupingMap; - public GroupByKeyEvaluator( + public InProcessGroupByKeyOnlyEvaluator( InProcessEvaluationContext evaluationContext, - CommittedBundle> inputBundle, + CommittedBundle>> inputBundle, AppliedPTransform< PCollection>>, PCollection>, InProcessGroupByKeyOnly> @@ -101,16 +99,18 @@ public GroupByKeyEvaluator( this.evaluationContext = evaluationContext; this.inputBundle = inputBundle; this.application = application; - - PCollection>> input = application.getInput(); - keyCoder = getKeyCoder(input.getCoder()); - groupingMap = new HashMap<>(); + this.keyCoder = getKeyCoder(application.getInput().getCoder()); + this.groupingMap = new HashMap<>(); } private Coder getKeyCoder(Coder>> coder) { - if (!(coder instanceof KvCoder)) { - throw new IllegalStateException(); - } + checkState( + coder instanceof KvCoder, + "%s requires a coder of class %s." + + " This is an internal error; this is checked during pipeline construction" + + " but became corrupted.", + getClass().getSimpleName(), + KvCoder.class.getSimpleName()); @SuppressWarnings("unchecked") Coder keyCoder = ((KvCoder>) coder).getKeyCoder(); return keyCoder; @@ -180,95 +180,4 @@ public int hashCode() { } } } - - /** - * A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. - */ - public static final class InProcessGroupByKeyOverrideFactory - implements PTransformOverrideFactory { - @Override - public PTransform override( - PTransform transform) { - if (transform instanceof GroupByKey) { - @SuppressWarnings({"rawtypes", "unchecked"}) - PTransform override = new InProcessGroupByKey((GroupByKey) transform); - return override; - } - return transform; - } - } - - /** - * An in-memory implementation of the {@link GroupByKey} primitive as a composite - * {@link PTransform}. - */ - private static final class InProcessGroupByKey - extends ForwardingPTransform>, PCollection>>> { - private final GroupByKey original; - - private InProcessGroupByKey(GroupByKey from) { - this.original = from; - } - - @Override - public PTransform>, PCollection>>> delegate() { - return original; - } - - @Override - public PCollection>> apply(PCollection> input) { - KvCoder inputCoder = (KvCoder) input.getCoder(); - - // This operation groups by the combination of key and window, - // merging windows as needed, using the windows assigned to the - // key/value input elements and the window merge operation of the - // window function associated with the input PCollection. - WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - - // Use the default GroupAlsoByWindow implementation - DoFn, KV>> groupAlsoByWindow = - groupAlsoByWindow(windowingStrategy, inputCoder.getValueCoder()); - - // By default, implement GroupByKey via a series of lower-level operations. - return input - // Make each input element's timestamp and assigned windows - // explicit, in the value part. 
- .apply(new ReifyTimestampsAndWindows()) - - .apply(new InProcessGroupByKeyOnly()) - .setCoder(KeyedWorkItemCoder.of(inputCoder.getKeyCoder(), - inputCoder.getValueCoder(), input.getWindowingStrategy().getWindowFn().windowCoder())) - - // Group each key's values by window, merging windows as needed. - .apply("GroupAlsoByWindow", ParDo.of(groupAlsoByWindow)) - - // And update the windowing strategy as appropriate. - .setWindowingStrategyInternal(original.updateWindowingStrategy(windowingStrategy)) - .setCoder( - KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getValueCoder()))); - } - - private - DoFn, KV>> groupAlsoByWindow( - final WindowingStrategy windowingStrategy, final Coder inputCoder) { - return GroupAlsoByWindowViaWindowSetDoFn.create( - windowingStrategy, SystemReduceFn.buffering(inputCoder)); - } - } - - /** - * An implementation primitive to use in the evaluation of a {@link GroupByKey} - * {@link PTransform}. - */ - public static final class InProcessGroupByKeyOnly - extends PTransform>>, PCollection>> { - @Override - public PCollection> apply(PCollection>> input) { - return PCollection.>createPrimitiveOutputInternal( - input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); - } - - @VisibleForTesting - InProcessGroupByKeyOnly() {} - } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOverrideFactory.java new file mode 100644 index 000000000000..1d84bc905fee --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessGroupByKeyOverrideFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; + +/** + * A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. 
+ */ +final class InProcessGroupByKeyOverrideFactory + implements PTransformOverrideFactory { + @Override + public PTransform override( + PTransform transform) { + if (transform instanceof GroupByKey) { + @SuppressWarnings({"rawtypes", "unchecked"}) + PTransform override = + (PTransform) new InProcessGroupByKey((GroupByKey) transform); + return override; + } + return transform; + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessPipelineRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessPipelineRunner.java index 19e9f47de568..a7f6941a738c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessPipelineRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/InProcessPipelineRunner.java @@ -17,8 +17,7 @@ */ package org.apache.beam.runners.direct; -import org.apache.beam.runners.direct.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly; -import org.apache.beam.runners.direct.GroupByKeyEvaluatorFactory.InProcessGroupByKeyOverrideFactory; +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.ViewEvaluatorFactory.InProcessViewOverrideFactory; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java index f449731a5918..81d252087c8c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.direct; +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupAlsoByWindow; +import org.apache.beam.runners.direct.InProcessGroupByKey.InProcessGroupByKeyOnly; import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -44,12 +46,12 @@ public static TransformEvaluatorRegistry defaultRegistry() { .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory()) .put(ParDo.Bound.class, new ParDoSingleEvaluatorFactory()) .put(ParDo.BoundMulti.class, new ParDoMultiEvaluatorFactory()) - .put( - GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly.class, - new GroupByKeyEvaluatorFactory()) .put(FlattenPCollectionList.class, new FlattenEvaluatorFactory()) .put(ViewEvaluatorFactory.WriteView.class, new ViewEvaluatorFactory()) .put(Window.Bound.class, new WindowEvaluatorFactory()) + // Runner-specific primitives used in expansion of GroupByKey + .put(InProcessGroupByKeyOnly.class, new InProcessGroupByKeyOnlyEvaluatorFactory()) + .put(InProcessGroupAlsoByWindow.class, new InProcessGroupAlsoByWindowEvaluatorFactory()) .build(); return new TransformEvaluatorRegistry(primitives); } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java index 267266d3b891..92f845c8bbef 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java 
@@ -67,7 +67,7 @@ public void testInMemoryEvaluator() throws Exception { PCollection>> kvs = values.apply(new ReifyTimestampsAndWindows()); PCollection> groupedKvs = - kvs.apply(new GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly()); + kvs.apply(new InProcessGroupByKey.InProcessGroupByKeyOnly()); CommittedBundle>> inputBundle = bundleFactory.createRootBundle(kvs).commit(Instant.now()); @@ -89,7 +89,7 @@ public void testInMemoryEvaluator() throws Exception { Coder keyCoder = ((KvCoder>) kvs.getCoder()).getKeyCoder(); TransformEvaluator>> evaluator = - new GroupByKeyEvaluatorFactory() + new InProcessGroupByKeyOnlyEvaluatorFactory() .forApplication( groupedKvs.getProducingTransformInternal(), inputBundle, evaluationContext); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactoryTest.java new file mode 100644 index 000000000000..1172a4d08dfd --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/InProcessGroupByKeyOnlyEvaluatorFactoryTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.ReifyTimestampsAndWindows; +import org.apache.beam.sdk.util.KeyedWorkItem; +import org.apache.beam.sdk.util.KeyedWorkItems; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multiset; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link InProcessGroupByKeyOnlyEvaluatorFactory}. 
+ */ +@RunWith(JUnit4.class) +public class InProcessGroupByKeyOnlyEvaluatorFactoryTest { + private BundleFactory bundleFactory = InProcessBundleFactory.create(); + + @Test + public void testInMemoryEvaluator() throws Exception { + TestPipeline p = TestPipeline.create(); + KV firstFoo = KV.of("foo", -1); + KV secondFoo = KV.of("foo", 1); + KV thirdFoo = KV.of("foo", 3); + KV firstBar = KV.of("bar", 22); + KV secondBar = KV.of("bar", 12); + KV firstBaz = KV.of("baz", Integer.MAX_VALUE); + PCollection> values = + p.apply(Create.of(firstFoo, firstBar, secondFoo, firstBaz, secondBar, thirdFoo)); + PCollection>> kvs = + values.apply(new ReifyTimestampsAndWindows()); + PCollection> groupedKvs = + kvs.apply(new InProcessGroupByKey.InProcessGroupByKeyOnly()); + + CommittedBundle>> inputBundle = + bundleFactory.createRootBundle(kvs).commit(Instant.now()); + InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class); + + UncommittedBundle> fooBundle = + bundleFactory.createKeyedBundle(null, "foo", groupedKvs); + UncommittedBundle> barBundle = + bundleFactory.createKeyedBundle(null, "bar", groupedKvs); + UncommittedBundle> bazBundle = + bundleFactory.createKeyedBundle(null, "baz", groupedKvs); + + when(evaluationContext.createKeyedBundle(inputBundle, "foo", groupedKvs)).thenReturn(fooBundle); + when(evaluationContext.createKeyedBundle(inputBundle, "bar", groupedKvs)).thenReturn(barBundle); + when(evaluationContext.createKeyedBundle(inputBundle, "baz", groupedKvs)).thenReturn(bazBundle); + + // The input to a GroupByKey is assumed to be a KvCoder + @SuppressWarnings("unchecked") + Coder keyCoder = + ((KvCoder>) kvs.getCoder()).getKeyCoder(); + TransformEvaluator>> evaluator = + new InProcessGroupByKeyOnlyEvaluatorFactory() + .forApplication( + groupedKvs.getProducingTransformInternal(), inputBundle, evaluationContext); + + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstFoo))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondFoo))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(thirdFoo))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBar))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondBar))); + evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBaz))); + + evaluator.finishBundle(); + + assertThat( + fooBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "foo", + ImmutableSet.of( + WindowedValue.valueInGlobalWindow(-1), + WindowedValue.valueInGlobalWindow(1), + WindowedValue.valueInGlobalWindow(3))), + keyCoder))); + assertThat( + barBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "bar", + ImmutableSet.of( + WindowedValue.valueInGlobalWindow(12), + WindowedValue.valueInGlobalWindow(22))), + keyCoder))); + assertThat( + bazBundle.commit(Instant.now()).getElements(), + contains( + new KeyedWorkItemMatcher( + KeyedWorkItems.elementsWorkItem( + "baz", + ImmutableSet.of(WindowedValue.valueInGlobalWindow(Integer.MAX_VALUE))), + keyCoder))); + } + + private KV> gwValue(KV kv) { + return KV.of(kv.getKey(), WindowedValue.valueInGlobalWindow(kv.getValue())); + } + + private static class KeyedWorkItemMatcher + extends BaseMatcher>> { + private final KeyedWorkItem myWorkItem; + private final Coder keyCoder; + + public KeyedWorkItemMatcher(KeyedWorkItem myWorkItem, Coder keyCoder) { + 
this.myWorkItem = myWorkItem; + this.keyCoder = keyCoder; + } + + @Override + public boolean matches(Object item) { + if (item == null || !(item instanceof WindowedValue)) { + return false; + } + WindowedValue> that = (WindowedValue>) item; + Multiset> myValues = HashMultiset.create(); + Multiset> thatValues = HashMultiset.create(); + for (WindowedValue value : myWorkItem.elementsIterable()) { + myValues.add(value); + } + for (WindowedValue value : that.getValue().elementsIterable()) { + thatValues.add(value); + } + try { + return myValues.equals(thatValues) + && keyCoder + .structuralValue(myWorkItem.key()) + .equals(keyCoder.structuralValue(that.getValue().key())); + } catch (Exception e) { + return false; + } + } + + @Override + public void describeTo(Description description) { + description + .appendText("KeyedWorkItem containing key ") + .appendValue(myWorkItem.key()) + .appendText(" and values ") + .appendValueList("[", ", ", "]", myWorkItem.elementsIterable()); + } + } +} From ebe69cf284e8b6a3c24d833d1eed2abec83500bf Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 18 May 2016 16:53:49 -0700 Subject: [PATCH 03/21] Add DefaultPipelineOptionsRegistrar This registers all of the PipelineOptions classes in the SDK. This ensures that SDK options are registered regardless of the class hierarchy of registered runner options. --- .../DefaultPipelineOptionsRegistrar.java | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/options/DefaultPipelineOptionsRegistrar.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/DefaultPipelineOptionsRegistrar.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/DefaultPipelineOptionsRegistrar.java new file mode 100644 index 000000000000..069c10983399 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/DefaultPipelineOptionsRegistrar.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.options; + +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableList; + +/** + * A registrar containing the default SDK options. 
+ */ +@AutoService(PipelineOptionsRegistrar.class) +public class DefaultPipelineOptionsRegistrar implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.>builder() + .add(PipelineOptions.class) + .add(ApplicationNameOptions.class) + .add(StreamingOptions.class) + .add(BigQueryOptions.class) + .add(GcpOptions.class) + .add(GcsOptions.class) + .add(GoogleApiDebugOptions.class) + .add(PubsubOptions.class) + .build(); + } +} From acad55e7ea8b51a77707c5db41ce88b0f4f8808e Mon Sep 17 00:00:00 2001 From: Jianfeng Qian Date: Thu, 19 May 2016 19:01:02 +0800 Subject: [PATCH 04/21] update of flink README.me version of org.apache.beam should be 0.1.0-incubating-SNAPSHOT. change line 145 0.4-SNAPSHOT to 0.1.0-incubating-SNAPSHOT --- runners/flink/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/flink/README.md b/runners/flink/README.md index 7418f16fe6b4..297ef7b15304 100644 --- a/runners/flink/README.md +++ b/runners/flink/README.md @@ -142,7 +142,7 @@ The contents of the root `pom.xml` should be slightly changed aftewards (explana org.apache.beam flink-runner_2.10 - 0.4-SNAPSHOT + 0.1.0-incubating-SNAPSHOT @@ -200,4 +200,4 @@ folder to the Flink cluster using the command-line utility like so: # More For more information, please visit the [Apache Flink Website](http://flink.apache.org) or contact -the [Mailinglists](http://flink.apache.org/community.html#mailing-lists). \ No newline at end of file +the [Mailinglists](http://flink.apache.org/community.html#mailing-lists). From 26941f152cb5bed422ff14ccb10403604a611130 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 11 Apr 2016 17:36:27 -0700 Subject: [PATCH 05/21] PubsubIO: integrate the new PubsubUnboundedSource and Sink --- .../dataflow/DataflowPipelineRunner.java | 249 +++++++++++++++--- .../dataflow/DataflowPipelineTranslator.java | 8 - .../dataflow/internal/PubsubIOTranslator.java | 108 -------- .../dataflow/io/DataflowPubsubIOTest.java | 13 +- .../java/org/apache/beam/sdk/io/PubsubIO.java | 102 ++++--- .../beam/sdk/io/PubsubUnboundedSink.java | 67 ++++- .../beam/sdk/io/PubsubUnboundedSource.java | 131 +++++++-- .../beam/sdk/util/PubsubApiaryClient.java | 20 +- .../apache/beam/sdk/util/PubsubClient.java | 82 ++++-- .../beam/sdk/util/PubsubGrpcClient.java | 19 +- .../beam/sdk/util/PubsubTestClient.java | 21 +- .../beam/sdk/io/PubsubUnboundedSinkTest.java | 48 ++-- .../sdk/io/PubsubUnboundedSourceTest.java | 8 +- .../beam/sdk/util/PubsubApiaryClientTest.java | 12 +- .../beam/sdk/util/PubsubGrpcClientTest.java | 12 +- .../beam/sdk/util/PubsubTestClientTest.java | 12 +- 16 files changed, 604 insertions(+), 308 deletions(-) delete mode 100755 runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/PubsubIOTranslator.java diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java index 88018965bff7..0c77191a7bd1 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java @@ -33,7 +33,6 @@ import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecordCoder; import 
org.apache.beam.runners.dataflow.internal.IsmFormat.MetadataKeyCoder; -import org.apache.beam.runners.dataflow.internal.PubsubIOTranslator; import org.apache.beam.runners.dataflow.internal.ReadTranslator; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; @@ -63,6 +62,8 @@ import org.apache.beam.sdk.io.BigQueryIO; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.PubsubIO; +import org.apache.beam.sdk.io.PubsubUnboundedSink; +import org.apache.beam.sdk.io.PubsubUnboundedSource; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.ShardNameTemplate; import org.apache.beam.sdk.io.TextIO; @@ -107,6 +108,7 @@ import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; @@ -177,6 +179,7 @@ import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; +import javax.annotation.Nullable; /** * A {@link PipelineRunner} that executes the operations in the @@ -338,33 +341,46 @@ public static DataflowPipelineRunner fromOptions(PipelineOptions options) { this.pcollectionsRequiringIndexedFormat = new HashSet<>(); this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>(); + ImmutableMap.Builder, Class> builder = ImmutableMap., Class>builder(); if (options.isStreaming()) { - overrides = ImmutableMap., Class>builder() - .put(Combine.GloballyAsSingletonView.class, StreamingCombineGloballyAsSingletonView.class) - .put(Create.Values.class, StreamingCreate.class) - .put(View.AsMap.class, StreamingViewAsMap.class) - .put(View.AsMultimap.class, StreamingViewAsMultimap.class) - .put(View.AsSingleton.class, StreamingViewAsSingleton.class) - .put(View.AsList.class, StreamingViewAsList.class) - .put(View.AsIterable.class, StreamingViewAsIterable.class) - .put(Write.Bound.class, StreamingWrite.class) - .put(PubsubIO.Write.Bound.class, StreamingPubsubIOWrite.class) - .put(Read.Unbounded.class, StreamingUnboundedRead.class) - .put(Read.Bounded.class, UnsupportedIO.class) - .put(AvroIO.Read.Bound.class, UnsupportedIO.class) - .put(AvroIO.Write.Bound.class, UnsupportedIO.class) - .put(BigQueryIO.Read.Bound.class, UnsupportedIO.class) - .put(TextIO.Read.Bound.class, UnsupportedIO.class) - .put(TextIO.Write.Bound.class, UnsupportedIO.class) - .put(Window.Bound.class, AssignWindows.class) - .build(); + builder.put(Combine.GloballyAsSingletonView.class, + StreamingCombineGloballyAsSingletonView.class); + builder.put(Create.Values.class, StreamingCreate.class); + builder.put(View.AsMap.class, StreamingViewAsMap.class); + builder.put(View.AsMultimap.class, StreamingViewAsMultimap.class); + builder.put(View.AsSingleton.class, StreamingViewAsSingleton.class); + builder.put(View.AsList.class, StreamingViewAsList.class); + builder.put(View.AsIterable.class, StreamingViewAsIterable.class); + builder.put(Write.Bound.class, StreamingWrite.class); + builder.put(Read.Unbounded.class, StreamingUnboundedRead.class); + builder.put(Read.Bounded.class, UnsupportedIO.class); + builder.put(AvroIO.Read.Bound.class, UnsupportedIO.class); + builder.put(AvroIO.Write.Bound.class, UnsupportedIO.class); + builder.put(BigQueryIO.Read.Bound.class, UnsupportedIO.class); + 
builder.put(TextIO.Read.Bound.class, UnsupportedIO.class); + builder.put(TextIO.Write.Bound.class, UnsupportedIO.class); + builder.put(Window.Bound.class, AssignWindows.class); + // In streaming mode must use either the custom Pubsub unbounded source/sink or + // defer to Windmill's built-in implementation. + builder.put(PubsubIO.Read.Bound.PubsubBoundedReader.class, UnsupportedIO.class); + builder.put(PubsubIO.Write.Bound.PubsubBoundedWriter.class, UnsupportedIO.class); + if (options.getExperiments() == null + || !options.getExperiments().contains("enable_custom_pubsub_source")) { + builder.put(PubsubUnboundedSource.class, StreamingPubsubIORead.class); + } + if (options.getExperiments() == null + || !options.getExperiments().contains("enable_custom_pubsub_sink")) { + builder.put(PubsubUnboundedSink.class, StreamingPubsubIOWrite.class); + } } else { - ImmutableMap.Builder, Class> builder = ImmutableMap., Class>builder(); builder.put(Read.Unbounded.class, UnsupportedIO.class); builder.put(Window.Bound.class, AssignWindows.class); builder.put(Write.Bound.class, BatchWrite.class); builder.put(AvroIO.Write.Bound.class, BatchAvroIOWrite.class); builder.put(TextIO.Write.Bound.class, BatchTextIOWrite.class); + // In batch mode must use the custom Pubsub bounded source/sink. + builder.put(PubsubUnboundedSource.class, UnsupportedIO.class); + builder.put(PubsubUnboundedSink.class, UnsupportedIO.class); if (options.getExperiments() == null || !options.getExperiments().contains("disable_ism_side_input")) { builder.put(View.AsMap.class, BatchViewAsMap.class); @@ -373,8 +389,8 @@ public static DataflowPipelineRunner fromOptions(PipelineOptions options) { builder.put(View.AsList.class, BatchViewAsList.class); builder.put(View.AsIterable.class, BatchViewAsIterable.class); } - overrides = builder.build(); } + overrides = builder.build(); } /** @@ -2336,27 +2352,104 @@ protected String getKindString() { } } + // ================================================================================ + // PubsubIO translations + // ================================================================================ + /** - * Specialized implementation for - * {@link org.apache.beam.sdk.io.PubsubIO.Write PubsubIO.Write} for the - * Dataflow runner in streaming mode. - * - *

For internal use only. Subject to change at any time. - * - *

Public so the {@link PubsubIOTranslator} can access. + * Suppress application of {@link PubsubUnboundedSource#apply} in streaming mode so that we + * can instead defer to Windmill's implementation. */ - public static class StreamingPubsubIOWrite extends PTransform, PDone> { - private final PubsubIO.Write.Bound transform; + private static class StreamingPubsubIORead extends PTransform> { + private final PubsubUnboundedSource transform; + + /** + * Builds an instance of this class from the overridden transform. + */ + public StreamingPubsubIORead( + DataflowPipelineRunner runner, PubsubUnboundedSource transform) { + this.transform = transform; + } + + PubsubUnboundedSource getOverriddenTransform() { + return transform; + } + + @Override + public PCollection apply(PBegin input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) + .setCoder(transform.getElementCoder()); + } + + @Override + protected String getKindString() { + return "StreamingPubsubIORead"; + } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + StreamingPubsubIORead.class, new StreamingPubsubIOReadTranslator()); + } + } + + /** + * Rewrite {@link StreamingPubsubIORead} to the appropriate internal node. + */ + private static class StreamingPubsubIOReadTranslator implements + TransformTranslator { + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public void translate( + StreamingPubsubIORead transform, + TranslationContext context) { + translateTyped(transform, context); + } + + private void translateTyped( + StreamingPubsubIORead transform, + TranslationContext context) { + checkArgument(context.getPipelineOptions().isStreaming(), + "StreamingPubsubIORead is only for streaming pipelines."); + PubsubUnboundedSource overriddenTransform = transform.getOverriddenTransform(); + context.addStep(transform, "ParallelRead"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + if (overriddenTransform.getTopic() != null) { + context.addInput(PropertyNames.PUBSUB_TOPIC, + overriddenTransform.getTopic().getV1Beta1Path()); + } + if (overriddenTransform.getSubscription() != null) { + context.addInput( + PropertyNames.PUBSUB_SUBSCRIPTION, + overriddenTransform.getSubscription().getV1Beta1Path()); + } + if (overriddenTransform.getTimestampLabel() != null) { + context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, + overriddenTransform.getTimestampLabel()); + } + if (overriddenTransform.getIdLabel() != null) { + context.addInput(PropertyNames.PUBSUB_ID_LABEL, overriddenTransform.getIdLabel()); + } + context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); + } + } + + /** + * Suppress application of {@link PubsubUnboundedSink#apply} in streaming mode so that we + * can instead defer to Windmill's implementation. + */ + private static class StreamingPubsubIOWrite extends PTransform, PDone> { + private final PubsubUnboundedSink transform; /** * Builds an instance of this class from the overridden transform. 
*/ public StreamingPubsubIOWrite( - DataflowPipelineRunner runner, PubsubIO.Write.Bound transform) { + DataflowPipelineRunner runner, PubsubUnboundedSink transform) { this.transform = transform; } - public PubsubIO.Write.Bound getOverriddenTransform() { + PubsubUnboundedSink getOverriddenTransform() { return transform; } @@ -2369,8 +2462,51 @@ public PDone apply(PCollection input) { protected String getKindString() { return "StreamingPubsubIOWrite"; } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + StreamingPubsubIOWrite.class, new StreamingPubsubIOWriteTranslator()); + } + } + + /** + * Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal node. + */ + private static class StreamingPubsubIOWriteTranslator implements + TransformTranslator { + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public void translate( + StreamingPubsubIOWrite transform, + TranslationContext context) { + translateTyped(transform, context); + } + + private void translateTyped( + StreamingPubsubIOWrite transform, + TranslationContext context) { + checkArgument(context.getPipelineOptions().isStreaming(), + "StreamingPubsubIOWrite is only for streaming pipelines."); + PubsubUnboundedSink overriddenTransform = transform.getOverriddenTransform(); + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.FORMAT, "pubsub"); + context.addInput(PropertyNames.PUBSUB_TOPIC, overriddenTransform.getTopic().getV1Beta1Path()); + if (overriddenTransform.getTimestampLabel() != null) { + context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, + overriddenTransform.getTimestampLabel()); + } + if (overriddenTransform.getIdLabel() != null) { + context.addInput(PropertyNames.PUBSUB_ID_LABEL, overriddenTransform.getIdLabel()); + } + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(overriddenTransform.getElementCoder())); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + } } + // ================================================================================ + /** * Specialized implementation for * {@link org.apache.beam.sdk.io.Read.Unbounded Read.Unbounded} for the @@ -2912,11 +3048,14 @@ public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inp } /** - * Specialized expansion for unsupported IO transforms that throws an error. + * Specialized expansion for unsupported IO transforms and DoFns that throws an error. */ private static class UnsupportedIO extends PTransform { + @Nullable private PTransform transform; + @Nullable + private DoFn doFn; /** * Builds an instance of this class from the overridden transform. @@ -2974,13 +3113,51 @@ public UnsupportedIO(DataflowPipelineRunner runner, TextIO.Write.Bound transf this.transform = transform; } + /** + * Builds an instance of this class from the overridden doFn. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, + PubsubIO.Read.Bound.PubsubBoundedReader doFn) { + this.doFn = doFn; + } + + /** + * Builds an instance of this class from the overridden doFn. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, + PubsubIO.Write.Bound.PubsubBoundedWriter doFn) { + this.doFn = doFn; + } + + /** + * Builds an instance of this class from the overridden transform. 
+ */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, PubsubUnboundedSource transform) { + this.transform = transform; + } + + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public UnsupportedIO(DataflowPipelineRunner runner, PubsubUnboundedSink transform) { + this.transform = transform; + } + + @Override public OutputT apply(InputT input) { String mode = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming() ? "streaming" : "batch"; + String name = + transform == null + ? approximateSimpleName(doFn.getClass()) + : approximatePTransformName(transform.getClass()); throw new UnsupportedOperationException( - String.format("The DataflowPipelineRunner in %s mode does not support %s.", - mode, approximatePTransformName(transform.getClass()))); + String.format("The DataflowPipelineRunner in %s mode does not support %s.", mode, name)); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index d82280393910..7f673932f33d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -32,7 +32,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import org.apache.beam.runners.dataflow.DataflowPipelineRunner.GroupByKeyAndSortValuesOnly; -import org.apache.beam.runners.dataflow.internal.PubsubIOTranslator; import org.apache.beam.runners.dataflow.internal.ReadTranslator; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.util.DoFnInfo; @@ -41,7 +40,6 @@ import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.io.PubsubIO; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.runners.TransformTreeNode; @@ -1009,12 +1007,6 @@ private void translateHelper( /////////////////////////////////////////////////////////////////////////// // IO Translation. - registerTransformTranslator( - PubsubIO.Read.Bound.class, new PubsubIOTranslator.ReadTranslator()); - registerTransformTranslator( - DataflowPipelineRunner.StreamingPubsubIOWrite.class, - new PubsubIOTranslator.WriteTranslator()); - registerTransformTranslator(Read.Bounded.class, new ReadTranslator()); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/PubsubIOTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/PubsubIOTranslator.java deleted file mode 100755 index 976f948dd1c2..000000000000 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/PubsubIOTranslator.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.internal; - -import org.apache.beam.runners.dataflow.DataflowPipelineRunner; -import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.TransformTranslator; -import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.TranslationContext; -import org.apache.beam.sdk.io.PubsubIO; -import org.apache.beam.sdk.util.PropertyNames; -import org.apache.beam.sdk.util.WindowedValue; - -/** - * Pubsub transform support code for the Dataflow backend. - */ -public class PubsubIOTranslator { - - /** - * Implements PubsubIO Read translation for the Dataflow backend. - */ - public static class ReadTranslator implements TransformTranslator> { - @Override - @SuppressWarnings({"rawtypes", "unchecked"}) - public void translate( - PubsubIO.Read.Bound transform, - TranslationContext context) { - translateReadHelper(transform, context); - } - - private void translateReadHelper( - PubsubIO.Read.Bound transform, - TranslationContext context) { - if (!context.getPipelineOptions().isStreaming()) { - throw new IllegalArgumentException( - "PubsubIO.Read can only be used with the Dataflow streaming runner."); - } - - context.addStep(transform, "ParallelRead"); - context.addInput(PropertyNames.FORMAT, "pubsub"); - if (transform.getTopic() != null) { - context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path()); - } - if (transform.getSubscription() != null) { - context.addInput( - PropertyNames.PUBSUB_SUBSCRIPTION, transform.getSubscription().asV1Beta1Path()); - } - if (transform.getTimestampLabel() != null) { - context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel()); - } - if (transform.getIdLabel() != null) { - context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel()); - } - context.addValueOnlyOutput(PropertyNames.OUTPUT, context.getOutput(transform)); - } - } - - /** - * Implements PubsubIO Write translation for the Dataflow backend. 
- */ - public static class WriteTranslator - implements TransformTranslator> { - - @Override - @SuppressWarnings({"rawtypes", "unchecked"}) - public void translate( - DataflowPipelineRunner.StreamingPubsubIOWrite transform, - TranslationContext context) { - translateWriteHelper(transform, context); - } - - private void translateWriteHelper( - DataflowPipelineRunner.StreamingPubsubIOWrite customTransform, - TranslationContext context) { - if (!context.getPipelineOptions().isStreaming()) { - throw new IllegalArgumentException( - "PubsubIO.Write is non-primitive for the Dataflow batch runner."); - } - - PubsubIO.Write.Bound transform = customTransform.getOverriddenTransform(); - - context.addStep(customTransform, "ParallelWrite"); - context.addInput(PropertyNames.FORMAT, "pubsub"); - context.addInput(PropertyNames.PUBSUB_TOPIC, transform.getTopic().asV1Beta1Path()); - if (transform.getTimestampLabel() != null) { - context.addInput(PropertyNames.PUBSUB_TIMESTAMP_LABEL, transform.getTimestampLabel()); - } - if (transform.getIdLabel() != null) { - context.addInput(PropertyNames.PUBSUB_ID_LABEL, transform.getIdLabel()); - } - context.addEncodingInput(WindowedValue.getValueOnlyCoder(transform.getCoder())); - context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(customTransform)); - } - } -} diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/io/DataflowPubsubIOTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/io/DataflowPubsubIOTest.java index 4874877d73fd..3df9cdb9d4e7 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/io/DataflowPubsubIOTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/io/DataflowPubsubIOTest.java @@ -42,21 +42,22 @@ public class DataflowPubsubIOTest { @Test public void testPrimitiveWriteDisplayData() { DisplayDataEvaluator evaluator = DataflowDisplayDataEvaluator.create(); - PubsubIO.Write.Bound write = PubsubIO.Write - .topic("projects/project/topics/topic"); + PubsubIO.Write.Bound write = PubsubIO.Write.topic("projects/project/topics/topic"); Set displayData = evaluator.displayDataForPrimitiveTransforms(write); assertThat("PubsubIO.Write should include the topic in its primitive display data", - displayData, hasItem(hasDisplayItem("topic"))); + displayData, hasItem(hasDisplayItem("topic"))); } @Test public void testPrimitiveReadDisplayData() { DisplayDataEvaluator evaluator = DataflowDisplayDataEvaluator.create(); - PubsubIO.Read.Bound read = PubsubIO.Read.topic("projects/project/topics/topic"); + PubsubIO.Read.Bound read = + PubsubIO.Read.subscription("projects/project/subscriptions/subscription") + .maxNumRecords(1); Set displayData = evaluator.displayDataForPrimitiveTransforms(read); - assertThat("PubsubIO.Read should include the topic in its primitive display data", - displayData, hasItem(hasDisplayItem("topic"))); + assertThat("PubsubIO.Read should include the subscription in its primitive display data", + displayData, hasItem(hasDisplayItem("subscription"))); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubIO.java index 78fec852c666..23a11401dedb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubIO.java @@ -36,10 +36,10 @@ import org.apache.beam.sdk.util.PubsubClient; import 
org.apache.beam.sdk.util.PubsubClient.IncomingMessage; import org.apache.beam.sdk.util.PubsubClient.OutgoingMessage; +import org.apache.beam.sdk.util.PubsubClient.ProjectPath; +import org.apache.beam.sdk.util.PubsubClient.SubscriptionPath; import org.apache.beam.sdk.util.PubsubClient.TopicPath; -import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; @@ -54,7 +54,6 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import java.util.Random; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.annotation.Nullable; @@ -634,12 +633,12 @@ public Bound maxReadTime(Duration maxReadTime) { @Override public PCollection apply(PInput input) { if (topic == null && subscription == null) { - throw new IllegalStateException("need to set either the topic or the subscription for " + throw new IllegalStateException("Need to set either the topic or the subscription for " + "a PubsubIO.Read transform"); } if (topic != null && subscription != null) { - throw new IllegalStateException("Can't set both the topic and the subscription for a " - + "PubsubIO.Read transform"); + throw new IllegalStateException("Can't set both the topic and the subscription for " + + "a PubsubIO.Read transform"); } boolean boundedOutput = getMaxNumRecords() > 0 || getMaxReadTime() != null; @@ -649,9 +648,19 @@ public PCollection apply(PInput input) { .apply(Create.of((Void) null)).setCoder(VoidCoder.of()) .apply(ParDo.of(new PubsubBoundedReader())).setCoder(coder); } else { - return PCollection.createPrimitiveOutputInternal( - input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) - .setCoder(coder); + @Nullable ProjectPath projectPath = + topic == null ? null : PubsubClient.projectPathFromId(topic.project); + @Nullable TopicPath topicPath = + topic == null ? null : PubsubClient.topicPathFromName(topic.project, topic.topic); + @Nullable SubscriptionPath subscriptionPath = + subscription == null + ? null + : PubsubClient + .subscriptionPathFromName(subscription.project, subscription.subscription); + return input.getPipeline().begin() + .apply(new PubsubUnboundedSource( + FACTORY, projectPath, topicPath, subscriptionPath, + coder, timestampLabel, idLabel)); } } @@ -707,12 +716,16 @@ public Duration getMaxReadTime() { /** * Default reader when Pubsub subscription has some form of upper bound. - *

TODO: Consider replacing with BoundedReadFromUnboundedSource on top of upcoming - * PubsubUnboundedSource. - *

NOTE: This is not the implementation used when running on the Google Dataflow hosted - * service. + * + *

TODO: Consider replacing with BoundedReadFromUnboundedSource on top + * of PubsubUnboundedSource. + * + *

NOTE: This is not the implementation used when running on the Google Cloud Dataflow + * service in streaming mode. + * + *

Public so can be suppressed by runners. */ - private class PubsubBoundedReader extends DoFn { + public class PubsubBoundedReader extends DoFn { private static final int DEFAULT_PULL_SIZE = 100; private static final int ACK_TIMEOUT_SEC = 60; @@ -724,20 +737,20 @@ public void processElement(ProcessContext c) throws IOException { PubsubClient.SubscriptionPath subscriptionPath; if (getSubscription() == null) { - // Create a randomized subscription derived from the topic name. - String subscription = getTopic().topic + "_dataflow_" + new Random().nextLong(); + TopicPath topicPath = + PubsubClient.topicPathFromName(getTopic().project, getTopic().topic); // The subscription will be registered under this pipeline's project if we know it. // Otherwise we'll fall back to the topic's project. // Note that they don't need to be the same. - String project = c.getPipelineOptions().as(PubsubOptions.class).getProject(); - if (Strings.isNullOrEmpty(project)) { - project = getTopic().project; + String projectId = + c.getPipelineOptions().as(PubsubOptions.class).getProject(); + if (Strings.isNullOrEmpty(projectId)) { + projectId = getTopic().project; } - subscriptionPath = PubsubClient.subscriptionPathFromName(project, subscription); - TopicPath topicPath = - PubsubClient.topicPathFromName(getTopic().project, getTopic().topic); + ProjectPath projectPath = PubsubClient.projectPathFromId(projectId); try { - pubsubClient.createSubscription(topicPath, subscriptionPath, ACK_TIMEOUT_SEC); + subscriptionPath = + pubsubClient.createRandomSubscription(projectPath, topicPath, ACK_TIMEOUT_SEC); } catch (Exception e) { throw new RuntimeException("Failed to create subscription: ", e); } @@ -795,6 +808,12 @@ public void processElement(ProcessContext c) throws IOException { } } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + Bound.this.populateDisplayData(builder); + } } } @@ -961,8 +980,20 @@ public PDone apply(PCollection input) { if (topic == null) { throw new IllegalStateException("need to set the topic of a PubsubIO.Write transform"); } - input.apply(ParDo.of(new PubsubWriter())); - return PDone.in(input.getPipeline()); + switch (input.isBounded()) { + case BOUNDED: + input.apply(ParDo.of(new PubsubBoundedWriter())); + return PDone.in(input.getPipeline()); + case UNBOUNDED: + return input.apply(new PubsubUnboundedSink( + FACTORY, + PubsubClient.topicPathFromName(topic.project, topic.topic), + coder, + timestampLabel, + idLabel, + 100 /* numShards */)); + } + throw new RuntimeException(); // cases are exhaustive. } @Override @@ -993,11 +1024,14 @@ public Coder getCoder() { } /** - * Writer to Pubsub which batches messages. - *

NOTE: This is not the implementation used when running on the Google Dataflow hosted - * service. + * Writer to Pubsub which batches messages from bounded collections. + * + *

NOTE: This is not the implementation used when running on the Google Cloud Dataflow + * service in streaming mode. + * + *

Public so can be suppressed by runners. */ - private class PubsubWriter extends DoFn { + public class PubsubBoundedWriter extends DoFn { private static final int MAX_PUBLISH_BATCH_SIZE = 100; private transient List output; private transient PubsubClient pubsubClient; @@ -1005,15 +1039,18 @@ private class PubsubWriter extends DoFn { @Override public void startBundle(Context c) throws IOException { this.output = new ArrayList<>(); - this.pubsubClient = FACTORY.newClient(timestampLabel, idLabel, - c.getPipelineOptions().as(PubsubOptions.class)); + // NOTE: idLabel is ignored. + this.pubsubClient = + FACTORY.newClient(timestampLabel, null, + c.getPipelineOptions().as(PubsubOptions.class)); } @Override public void processElement(ProcessContext c) throws IOException { + // NOTE: The record id is always null. OutgoingMessage message = new OutgoingMessage(CoderUtils.encodeToByteArray(getCoder(), c.element()), - c.timestamp().getMillis()); + c.timestamp().getMillis(), null); output.add(message); if (output.size() >= MAX_PUBLISH_BATCH_SIZE) { @@ -1041,6 +1078,7 @@ private void publish() throws IOException { @Override public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); Bound.this.populateDisplayData(builder); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSink.java index 7ca2b57bc4cf..6ff9b40d39d0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSink.java @@ -26,6 +26,8 @@ import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.options.PubsubOptions; import org.apache.beam.sdk.transforms.Aggregator; @@ -52,6 +54,7 @@ import org.apache.beam.sdk.values.PDone; import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.Hashing; import org.joda.time.Duration; import org.slf4j.Logger; @@ -62,6 +65,7 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; @@ -81,6 +85,8 @@ *

  • A failed bundle will cause messages to be resent. Thus we rely on the Pubsub consumer * to dedup messages. * + * + *

    NOTE: This is not the implementation used when running on the Google Cloud Dataflow service. */ public class PubsubUnboundedSink extends PTransform, PDone> { private static final Logger LOG = LoggerFactory.getLogger(PubsubUnboundedSink.class); @@ -104,12 +110,16 @@ public class PubsubUnboundedSink extends PTransform, PDone> { * Coder for conveying outgoing messages between internal stages. */ private static class OutgoingMessageCoder extends CustomCoder { + private static final NullableCoder RECORD_ID_CODER = + NullableCoder.of(StringUtf8Coder.of()); + @Override public void encode( OutgoingMessage value, OutputStream outStream, Context context) throws CoderException, IOException { ByteArrayCoder.of().encode(value.elementBytes, outStream, Context.NESTED); BigEndianLongCoder.of().encode(value.timestampMsSinceEpoch, outStream, Context.NESTED); + RECORD_ID_CODER.encode(value.recordId, outStream, Context.NESTED); } @Override @@ -117,13 +127,31 @@ public OutgoingMessage decode( InputStream inStream, Context context) throws CoderException, IOException { byte[] elementBytes = ByteArrayCoder.of().decode(inStream, Context.NESTED); long timestampMsSinceEpoch = BigEndianLongCoder.of().decode(inStream, Context.NESTED); - return new OutgoingMessage(elementBytes, timestampMsSinceEpoch); + @Nullable String recordId = RECORD_ID_CODER.decode(inStream, Context.NESTED); + return new OutgoingMessage(elementBytes, timestampMsSinceEpoch, recordId); } } @VisibleForTesting static final Coder CODER = new OutgoingMessageCoder(); + // ================================================================================ + // RecordIdMethod + // ================================================================================ + + /** + * Specify how record ids are to be generated. + */ + @VisibleForTesting + enum RecordIdMethod { + /** Leave null. */ + NONE, + /** Generate randomly. */ + RANDOM, + /** Generate deterministically. For testing only. */ + DETERMINISTIC + } + // ================================================================================ // ShardFn // ================================================================================ @@ -136,10 +164,12 @@ private static class ShardFn extends DoFn> { createAggregator("elements", new Sum.SumLongFn()); private final Coder elementCoder; private final int numShards; + private final RecordIdMethod recordIdMethod; - ShardFn(Coder elementCoder, int numShards) { + ShardFn(Coder elementCoder, int numShards, RecordIdMethod recordIdMethod) { this.elementCoder = elementCoder; this.numShards = numShards; + this.recordIdMethod = recordIdMethod; } @Override @@ -147,9 +177,23 @@ public void processElement(ProcessContext c) throws Exception { elementCounter.addValue(1L); byte[] elementBytes = CoderUtils.encodeToByteArray(elementCoder, c.element()); long timestampMsSinceEpoch = c.timestamp().getMillis(); - // TODO: A random record id should be assigned here. + @Nullable String recordId = null; + switch (recordIdMethod) { + case NONE: + break; + case DETERMINISTIC: + recordId = Hashing.murmur3_128().hashBytes(elementBytes).toString(); + break; + case RANDOM: + // Since these elements go through a GroupByKey, any failures while sending to + // Pubsub will be retried without falling back and generating a new record id. + // Thus even though we may send the same message to Pubsub twice, it is guaranteed + // to have the same record id. 
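To make the two RecordIdMethod strategies above concrete, here is a minimal, self-contained sketch (not part of the patch) of deriving a deterministic id versus a random id; it assumes only Guava's Hashing and the JDK:

    import com.google.common.hash.Hashing;
    import java.nio.charset.StandardCharsets;
    import java.util.UUID;

    class RecordIdSketch {
      // DETERMINISTIC: the same payload always hashes to the same id (handy in tests).
      static String deterministicId(byte[] elementBytes) {
        return Hashing.murmur3_128().hashBytes(elementBytes).toString();
      }

      // RANDOM: a fresh UUID per element; in the sink it stays stable across publish
      // retries only because it is assigned before the GroupByKey and then replayed,
      // never regenerated.
      static String randomId() {
        return UUID.randomUUID().toString();
      }

      public static void main(String[] args) {
        byte[] payload = "hello".getBytes(StandardCharsets.UTF_8);
        System.out.println(deterministicId(payload)); // identical on every run
        System.out.println(randomId());               // different on every run
      }
    }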
+ recordId = UUID.randomUUID().toString(); + break; + } c.output(KV.of(ThreadLocalRandom.current().nextInt(numShards), - new OutgoingMessage(elementBytes, timestampMsSinceEpoch))); + new OutgoingMessage(elementBytes, timestampMsSinceEpoch, recordId))); } @Override @@ -319,6 +363,12 @@ public void populateDisplayData(Builder builder) { */ private final Duration maxLatency; + /** + * How record ids should be generated for each record (if {@link #idLabel} is non-{@literal + * null}). + */ + private final RecordIdMethod recordIdMethod; + @VisibleForTesting PubsubUnboundedSink( PubsubClientFactory pubsubFactory, @@ -329,7 +379,8 @@ public void populateDisplayData(Builder builder) { int numShards, int publishBatchSize, int publishBatchBytes, - Duration maxLatency) { + Duration maxLatency, + RecordIdMethod recordIdMethod) { this.pubsubFactory = pubsubFactory; this.topic = topic; this.elementCoder = elementCoder; @@ -339,6 +390,7 @@ public void populateDisplayData(Builder builder) { this.publishBatchSize = publishBatchSize; this.publishBatchBytes = publishBatchBytes; this.maxLatency = maxLatency; + this.recordIdMethod = idLabel == null ? RecordIdMethod.NONE : recordIdMethod; } public PubsubUnboundedSink( @@ -349,7 +401,8 @@ public PubsubUnboundedSink( String idLabel, int numShards) { this(pubsubFactory, topic, elementCoder, timestampLabel, idLabel, numShards, - DEFAULT_PUBLISH_BATCH_SIZE, DEFAULT_PUBLISH_BATCH_BYTES, DEFAULT_MAX_LATENCY); + DEFAULT_PUBLISH_BATCH_SIZE, DEFAULT_PUBLISH_BATCH_BYTES, DEFAULT_MAX_LATENCY, + RecordIdMethod.RANDOM); } public TopicPath getTopic() { @@ -382,7 +435,7 @@ public PDone apply(PCollection input) { .plusDelayOf(maxLatency)))) .discardingFiredPanes()) .apply(ParDo.named("PubsubUnboundedSink.Shard") - .of(new ShardFn(elementCoder, numShards))) + .of(new ShardFn(elementCoder, numShards, recordIdMethod))) .setCoder(KvCoder.of(VarIntCoder.of(), CODER)) .apply(GroupByKey.create()) .apply(ParDo.named("PubsubUnboundedSink.Writer") diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSource.java index d635a8a3860b..0492c7623677 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/PubsubUnboundedSource.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -42,13 +43,16 @@ import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.MovingFunction; import org.apache.beam.sdk.util.PubsubClient; +import org.apache.beam.sdk.util.PubsubClient.ProjectPath; import org.apache.beam.sdk.util.PubsubClient.PubsubClientFactory; import org.apache.beam.sdk.util.PubsubClient.SubscriptionPath; +import org.apache.beam.sdk.util.PubsubClient.TopicPath; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import com.google.api.client.util.Clock; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -102,10 +106,17 @@ * are blocking. We rely on the underlying runner to allow multiple * {@link UnboundedSource.UnboundedReader} instances to execute concurrently and thus hide latency. * + * + *

    NOTE: This is not the implementation used when running on the Google Cloud Dataflow service. */ public class PubsubUnboundedSource extends PTransform> { private static final Logger LOG = LoggerFactory.getLogger(PubsubUnboundedSource.class); + /** + * Default ACK timeout for created subscriptions. + */ + private static final int DEAULT_ACK_TIMEOUT_SEC = 60; + /** * Coder for checkpoints. */ @@ -291,6 +302,17 @@ public void finalizeCheckpoint() throws IOException { } } + /** + * Return current time according to {@code reader}. + */ + private static long now(PubsubReader reader) { + if (reader.outer.outer.clock == null) { + return System.currentTimeMillis(); + } else { + return reader.outer.outer.clock.currentTimeMillis(); + } + } + /** * BLOCKING * NACK all messages which have been read from Pubsub but not passed downstream. @@ -303,13 +325,13 @@ public void nackAll(PubsubReader reader) throws IOException { for (String ackId : notYetReadIds) { batchYetToAckIds.add(ackId); if (batchYetToAckIds.size() >= ACK_BATCH_SIZE) { - long nowMsSinceEpoch = reader.outer.outer.clock.currentTimeMillis(); + long nowMsSinceEpoch = now(reader); reader.nackBatch(nowMsSinceEpoch, batchYetToAckIds); batchYetToAckIds.clear(); } } if (!batchYetToAckIds.isEmpty()) { - long nowMsSinceEpoch = reader.outer.outer.clock.currentTimeMillis(); + long nowMsSinceEpoch = now(reader); reader.nackBatch(nowMsSinceEpoch, batchYetToAckIds); } } @@ -614,7 +636,11 @@ private void extendBatch(long nowMsSinceEpoch, List ackIds) throws IOExc * Return the current time, in ms since epoch. */ private long now() { - return outer.outer.clock.currentTimeMillis(); + if (outer.outer.clock == null) { + return System.currentTimeMillis(); + } else { + return outer.outer.clock.currentTimeMillis(); + } } /** @@ -928,7 +954,7 @@ public byte[] getCurrentRecordId() throws NoSuchElementException { if (current == null) { throw new NoSuchElementException(); } - return current.recordId; + return current.recordId.getBytes(Charsets.UTF_8); } @Override @@ -1124,8 +1150,9 @@ public void populateDisplayData(Builder builder) { // ================================================================================ /** - * Clock to use for all timekeeping. + * For testing only: Clock to use for all timekeeping. If {@literal null} use system clock. */ + @Nullable private Clock clock; /** @@ -1134,9 +1161,28 @@ public void populateDisplayData(Builder builder) { private final PubsubClientFactory pubsubFactory; /** - * Subscription to read from. + * Project under which to create a subscription if only the {@link #topic} was given. + */ + @Nullable + private final ProjectPath project; + + /** + * Topic to read from. If {@literal null}, then {@link #subscription} must be given. + * Otherwise {@link #subscription} must be null. */ - private final SubscriptionPath subscription; + @Nullable + private final TopicPath topic; + + /** + * Subscription to read from. If {@literal null} then {@link #topic} must be given. + * Otherwise {@link #topic} must be null. + * + *

    If no subscription is given a random one will be created when the transorm is + * applied. This field will be update with that subscription's path. The created + * subscription is never deleted. + */ + @Nullable + private SubscriptionPath subscription; /** * Coder for elements. Elements are effectively double-encoded: first to a byte array @@ -1159,25 +1205,60 @@ public void populateDisplayData(Builder builder) { @Nullable private final String idLabel; - /** - * Construct an unbounded source to consume from the Pubsub {@code subscription}. - */ - public PubsubUnboundedSource( + @VisibleForTesting + PubsubUnboundedSource( Clock clock, PubsubClientFactory pubsubFactory, - SubscriptionPath subscription, + @Nullable ProjectPath project, + @Nullable TopicPath topic, + @Nullable SubscriptionPath subscription, Coder elementCoder, @Nullable String timestampLabel, @Nullable String idLabel) { + checkArgument((topic == null) != (subscription == null), + "Exactly one of topic and subscription must be given"); + checkArgument((topic == null) == (project == null), + "Project must be given if topic is given"); this.clock = clock; this.pubsubFactory = checkNotNull(pubsubFactory); - this.subscription = checkNotNull(subscription); + this.project = project; + this.topic = topic; + this.subscription = subscription; this.elementCoder = checkNotNull(elementCoder); this.timestampLabel = timestampLabel; this.idLabel = idLabel; } - public PubsubClient.SubscriptionPath getSubscription() { + /** + * Construct an unbounded source to consume from the Pubsub {@code subscription}. + */ + public PubsubUnboundedSource( + PubsubClientFactory pubsubFactory, + @Nullable ProjectPath project, + @Nullable TopicPath topic, + @Nullable SubscriptionPath subscription, + Coder elementCoder, + @Nullable String timestampLabel, + @Nullable String idLabel) { + this(null, pubsubFactory, project, topic, subscription, elementCoder, timestampLabel, idLabel); + } + + public Coder getElementCoder() { + return elementCoder; + } + + @Nullable + public ProjectPath getProject() { + return project; + } + + @Nullable + public TopicPath getTopic() { + return topic; + } + + @Nullable + public SubscriptionPath getSubscription() { return subscription; } @@ -1191,12 +1272,26 @@ public String getIdLabel() { return idLabel; } - public Coder getElementCoder() { - return elementCoder; - } - @Override public PCollection apply(PBegin input) { + if (subscription == null) { + try { + try (PubsubClient pubsubClient = + pubsubFactory.newClient(timestampLabel, idLabel, + input.getPipeline() + .getOptions() + .as(PubsubOptions.class))) { + subscription = + pubsubClient.createRandomSubscription(project, topic, DEAULT_ACK_TIMEOUT_SEC); + LOG.warn("Created subscription {} to topic {}." 
+ + " Note this subscription WILL NOT be deleted when the pipeline terminates", + subscription, topic); + } + } catch (Exception e) { + throw new RuntimeException("Failed to create subscription: ", e); + } + } + return input.getPipeline().begin() .apply(Read.from(new PubsubSource(this))) .apply(ParDo.named("PubsubUnboundedSource.Stats") diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubApiaryClient.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubApiaryClient.java index aa73d421b695..08981d01212e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubApiaryClient.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubApiaryClient.java @@ -40,7 +40,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; -import com.google.common.hash.Hashing; import java.io.IOException; import java.util.ArrayList; @@ -135,11 +134,8 @@ public int publish(TopicPath topic, List outgoingMessages) attributes.put(timestampLabel, String.valueOf(outgoingMessage.timestampMsSinceEpoch)); } - if (idLabel != null) { - // TODO: The id should be associated with the OutgoingMessage so that it is stable - // across retried bundles - attributes.put(idLabel, - Hashing.murmur3_128().hashBytes(outgoingMessage.elementBytes).toString()); + if (idLabel != null && !Strings.isNullOrEmpty(outgoingMessage.recordId)) { + attributes.put(idLabel, outgoingMessage.recordId); } pubsubMessages.add(pubsubMessage); @@ -185,15 +181,13 @@ public List pull( checkState(!Strings.isNullOrEmpty(ackId)); // Record id, if any. - @Nullable byte[] recordId = null; + @Nullable String recordId = null; if (idLabel != null && attributes != null) { - String recordIdString = attributes.get(idLabel); - if (!Strings.isNullOrEmpty(recordIdString)) { - recordId = recordIdString.getBytes(); - } + recordId = attributes.get(idLabel); } - if (recordId == null) { - recordId = pubsubMessage.getMessageId().getBytes(); + if (Strings.isNullOrEmpty(recordId)) { + // Fall back to the Pubsub provided message id. 
+ recordId = pubsubMessage.getMessageId(); } incomingMessages.add(new IncomingMessage(elementBytes, timestampMsSinceEpoch, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubClient.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubClient.java index dc4858e2014f..07ce97df13bd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubClient.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubClient.java @@ -33,6 +33,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; /** @@ -132,6 +133,12 @@ public String getPath() { return path; } + public String getId() { + String[] splits = path.split("/"); + checkState(splits.length == 1, "Malformed project path %s", path); + return splits[1]; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -180,6 +187,12 @@ public String getPath() { return path; } + public String getName() { + String[] splits = path.split("/"); + checkState(splits.length == 4, "Malformed subscription path %s", path); + return splits[3]; + } + public String getV1Beta1Path() { String[] splits = path.split("/"); checkState(splits.length == 4, "Malformed subscription path %s", path); @@ -233,6 +246,12 @@ public String getPath() { return path; } + public String getName() { + String[] splits = path.split("/"); + checkState(splits.length == 4, "Malformed topic path %s", path); + return splits[3]; + } + public String getV1Beta1Path() { String[] splits = path.split("/"); checkState(splits.length == 4, "Malformed topic path %s", path); @@ -286,11 +305,18 @@ public static class OutgoingMessage implements Serializable { */ public final long timestampMsSinceEpoch; - // TODO: Support a record id. + /** + * If using an id label, the record id to associate with this record's metadata so the receiver + * can reject duplicates. Otherwise {@literal null}. + */ + @Nullable + public final String recordId; - public OutgoingMessage(byte[] elementBytes, long timestampMsSinceEpoch) { + public OutgoingMessage( + byte[] elementBytes, long timestampMsSinceEpoch, @Nullable String recordId) { this.elementBytes = elementBytes; this.timestampMsSinceEpoch = timestampMsSinceEpoch; + this.recordId = recordId; } @Override @@ -310,16 +336,14 @@ public boolean equals(Object o) { OutgoingMessage that = (OutgoingMessage) o; - if (timestampMsSinceEpoch != that.timestampMsSinceEpoch) { - return false; - } - return Arrays.equals(elementBytes, that.elementBytes); - + return timestampMsSinceEpoch == that.timestampMsSinceEpoch + && Arrays.equals(elementBytes, that.elementBytes) + && Objects.equal(recordId, that.recordId); } @Override public int hashCode() { - return Objects.hashCode(Arrays.hashCode(elementBytes), timestampMsSinceEpoch); + return Objects.hashCode(Arrays.hashCode(elementBytes), timestampMsSinceEpoch, recordId); } } @@ -353,14 +377,14 @@ public static class IncomingMessage implements Serializable { /** * Id to pass to the runner to distinguish this message from all others. 
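The same id-resolution rule appears in both the Apiary client above and the gRPC client later in this patch: prefer the value of the configured id attribute when it is present, otherwise fall back to the message id that Pubsub itself assigned. A small stand-alone sketch of that rule, assuming only Guava's Strings helper:

    import com.google.common.base.Strings;
    import java.util.Map;
    import javax.annotation.Nullable;

    class RecordIdResolution {
      /** Prefer the custom id attribute; otherwise use the Pubsub-assigned message id. */
      static String resolveRecordId(
          @Nullable String idLabel,
          @Nullable Map<String, String> attributes,
          String pubsubMessageId) {
        String recordId = null;
        if (idLabel != null && attributes != null) {
          recordId = attributes.get(idLabel);
        }
        return Strings.isNullOrEmpty(recordId) ? pubsubMessageId : recordId;
      }
    }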
*/ - public final byte[] recordId; + public final String recordId; public IncomingMessage( byte[] elementBytes, long timestampMsSinceEpoch, long requestTimeMsSinceEpoch, String ackId, - byte[] recordId) { + String recordId) { this.elementBytes = elementBytes; this.timestampMsSinceEpoch = timestampMsSinceEpoch; this.requestTimeMsSinceEpoch = requestTimeMsSinceEpoch; @@ -390,26 +414,18 @@ public boolean equals(Object o) { IncomingMessage that = (IncomingMessage) o; - if (timestampMsSinceEpoch != that.timestampMsSinceEpoch) { - return false; - } - if (requestTimeMsSinceEpoch != that.requestTimeMsSinceEpoch) { - return false; - } - if (!Arrays.equals(elementBytes, that.elementBytes)) { - return false; - } - if (!ackId.equals(that.ackId)) { - return false; - } - return Arrays.equals(recordId, that.recordId); + return timestampMsSinceEpoch == that.timestampMsSinceEpoch + && requestTimeMsSinceEpoch == that.requestTimeMsSinceEpoch + && ackId.equals(that.ackId) + && recordId.equals(that.recordId) + && Arrays.equals(elementBytes, that.elementBytes); } @Override public int hashCode() { return Objects.hashCode(Arrays.hashCode(elementBytes), timestampMsSinceEpoch, requestTimeMsSinceEpoch, - ackId, Arrays.hashCode(recordId)); + ackId, recordId); } } @@ -484,6 +500,22 @@ public abstract void modifyAckDeadline( public abstract void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException; + /** + * Create a random subscription for {@code topic}. Return the {@link SubscriptionPath}. It + * is the responsibility of the caller to later delete the subscription. + * + * @throws IOException + */ + public SubscriptionPath createRandomSubscription( + ProjectPath project, TopicPath topic, int ackDeadlineSeconds) throws IOException { + // Create a randomized subscription derived from the topic name. + String subscriptionName = topic.getName() + "_beam_" + ThreadLocalRandom.current().nextLong(); + SubscriptionPath subscription = + PubsubClient.subscriptionPathFromName(project.getId(), subscriptionName); + createSubscription(topic, subscription, ackDeadlineSeconds); + return subscription; + } + /** * Delete {@code subscription}. * @@ -507,7 +539,7 @@ public abstract List listSubscriptions(ProjectPath project, To public abstract int ackDeadlineSeconds(SubscriptionPath subscription) throws IOException; /** - * Return {@literal true} if {@link pull} will always return empty list. Actual clients + * Return {@literal true} if {@link #pull} will always return empty list. Actual clients * will return {@literal false}. Test clients may return {@literal true} to signal that all * expected messages have been pulled and the test may complete. 
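The new createRandomSubscription default shown above derives the subscription name from the topic name plus a random suffix; the real method also creates the subscription and leaves its deletion to the caller. A trimmed-down, self-contained sketch of just the naming scheme:

    import java.util.concurrent.ThreadLocalRandom;

    class RandomSubscriptionNames {
      // e.g. "mytopic_beam_8123749182374"; unique in practice, never cleaned up automatically.
      static String randomSubscriptionName(String topicName) {
        return topicName + "_beam_" + ThreadLocalRandom.current().nextLong();
      }

      public static void main(String[] args) {
        System.out.println(randomSubscriptionName("mytopic"));
      }
    }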
*/ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubGrpcClient.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubGrpcClient.java index e759513efb25..ac157fb80309 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubGrpcClient.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubGrpcClient.java @@ -27,7 +27,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; -import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; import com.google.protobuf.Timestamp; import com.google.pubsub.v1.AcknowledgeRequest; @@ -257,10 +256,8 @@ public int publish(TopicPath topic, List outgoingMessages) .put(timestampLabel, String.valueOf(outgoingMessage.timestampMsSinceEpoch)); } - if (idLabel != null) { - message.getMutableAttributes() - .put(idLabel, - Hashing.murmur3_128().hashBytes(outgoingMessage.elementBytes).toString()); + if (idLabel != null && !Strings.isNullOrEmpty(outgoingMessage.recordId)) { + message.getMutableAttributes().put(idLabel, outgoingMessage.recordId); } request.addMessages(message); @@ -308,15 +305,13 @@ public List pull( checkState(!Strings.isNullOrEmpty(ackId)); // Record id, if any. - @Nullable byte[] recordId = null; + @Nullable String recordId = null; if (idLabel != null && attributes != null) { - String recordIdString = attributes.get(idLabel); - if (recordIdString != null && !recordIdString.isEmpty()) { - recordId = recordIdString.getBytes(); - } + recordId = attributes.get(idLabel); } - if (recordId == null) { - recordId = pubsubMessage.getMessageId().getBytes(); + if (Strings.isNullOrEmpty(recordId)) { + // Fall back to the Pubsub provided message id. + recordId = pubsubMessage.getMessageId(); } incomingMessages.add(new IncomingMessage(elementBytes, timestampMsSinceEpoch, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubTestClient.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubTestClient.java index c1dfa060cc05..9fa03803836b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubTestClient.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PubsubTestClient.java @@ -46,10 +46,9 @@ public class PubsubTestClient extends PubsubClient { * Mimic the state of the simulated Pubsub 'service'. * * Note that the {@link PubsubTestClientFactory} is serialized/deserialized even when running - * test - * pipelines. Meanwhile it is valid for multiple {@link PubsubTestClient}s to be created from - * the same client factory and run in parallel. Thus we can't enforce aliasing of the following - * data structures over all clients and must resort to a static. + * test pipelines. Meanwhile it is valid for multiple {@link PubsubTestClient}s to be created + * from the same client factory and run in parallel. Thus we can't enforce aliasing of the + * following data structures over all clients and must resort to a static. */ private static class State { /** @@ -69,6 +68,13 @@ private static class State { @Nullable Set remainingExpectedOutgoingMessages; + /** + * Publish mode only: Messages which should throw when first sent to simulate transient publish + * failure. + */ + @Nullable + Set remainingFailingOutgoingMessages; + /** * Pull mode only: Clock from which to get current time. 
*/ @@ -119,11 +125,13 @@ public interface PubsubTestClientFactory extends PubsubClientFactory, Closeable */ public static PubsubTestClientFactory createFactoryForPublish( final TopicPath expectedTopic, - final Iterable expectedOutgoingMessages) { + final Iterable expectedOutgoingMessages, + final Iterable failingOutgoingMessages) { synchronized (STATE) { checkState(!STATE.isActive, "Test still in flight"); STATE.expectedTopic = expectedTopic; STATE.remainingExpectedOutgoingMessages = Sets.newHashSet(expectedOutgoingMessages); + STATE.remainingFailingOutgoingMessages = Sets.newHashSet(failingOutgoingMessages); STATE.isActive = true; } return new PubsubTestClientFactory() { @@ -257,6 +265,9 @@ public int publish( checkState(topic.equals(STATE.expectedTopic), "Topic %s does not match expected %s", topic, STATE.expectedTopic); for (OutgoingMessage outgoingMessage : outgoingMessages) { + if (STATE.remainingFailingOutgoingMessages.remove(outgoingMessage)) { + throw new RuntimeException("Simulating failure for " + outgoingMessage); + } checkState(STATE.remainingExpectedOutgoingMessages.remove(outgoingMessage), "Unexpected outgoing message %s", outgoingMessage); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSinkTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSinkTest.java index b4ef785a5b04..bf70e474d918 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSinkTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSinkTest.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk.io; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.PubsubUnboundedSink.RecordIdMethod; import org.apache.beam.sdk.testing.CoderProperties; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -31,7 +32,7 @@ import org.apache.beam.sdk.util.PubsubTestClient.PubsubTestClientFactory; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; +import com.google.common.hash.Hashing; import org.joda.time.Duration; import org.joda.time.Instant; @@ -41,9 +42,7 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; -import java.util.Set; /** * Test PubsubUnboundedSink. 
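Putting the new failingOutgoingMessages parameter to use might look roughly like the fragment below. This is only a sketch against the signatures introduced in this patch, not a complete test: TOPIC and the pipeline that actually publishes the message are assumed to be defined elsewhere, as in the tests that follow.

    OutgoingMessage msg =
        new OutgoingMessage("data".getBytes(), 1234L, "record-id-1");
    try (PubsubTestClientFactory factory =
        PubsubTestClient.createFactoryForPublish(
            TOPIC,
            ImmutableList.of(msg),    // messages that must eventually be published
            ImmutableList.of(msg))) { // messages whose first publish attempt throws
      // A pipeline publishing msg through this factory sees one simulated failure;
      // a retried publish of the same message is then accepted. The factory asserts
      // on close() that every expected message was actually seen.
    }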
@@ -55,6 +54,7 @@ public class PubsubUnboundedSinkTest { private static final long TIMESTAMP = 1234L; private static final String TIMESTAMP_LABEL = "timestamp"; private static final String ID_LABEL = "id"; + private static final int NUM_SHARDS = 10; private static class Stamp extends DoFn { @Override @@ -63,22 +63,30 @@ public void processElement(ProcessContext c) { } } + private String getRecordId(String data) { + return Hashing.murmur3_128().hashBytes(data.getBytes()).toString(); + } + @Test public void saneCoder() throws Exception { - OutgoingMessage message = new OutgoingMessage(DATA.getBytes(), TIMESTAMP); + OutgoingMessage message = new OutgoingMessage(DATA.getBytes(), TIMESTAMP, getRecordId(DATA)); CoderProperties.coderDecodeEncodeEqual(PubsubUnboundedSink.CODER, message); CoderProperties.coderSerializable(PubsubUnboundedSink.CODER); } @Test public void sendOneMessage() throws IOException { - Set outgoing = - Sets.newHashSet(new OutgoingMessage(DATA.getBytes(), TIMESTAMP)); + List outgoing = + ImmutableList.of(new OutgoingMessage(DATA.getBytes(), TIMESTAMP, getRecordId(DATA))); + int batchSize = 1; + int batchBytes = 1; try (PubsubTestClientFactory factory = - PubsubTestClient.createFactoryForPublish(TOPIC, outgoing)) { + PubsubTestClient.createFactoryForPublish(TOPIC, outgoing, + ImmutableList.of())) { PubsubUnboundedSink sink = new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL, - 10); + NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2), + RecordIdMethod.DETERMINISTIC); TestPipeline p = TestPipeline.create(); p.apply(Create.of(ImmutableList.of(DATA))) .apply(ParDo.of(new Stamp())) @@ -91,20 +99,22 @@ public void sendOneMessage() throws IOException { @Test public void sendMoreThanOneBatchByNumMessages() throws IOException { - Set outgoing = new HashSet<>(); + List outgoing = new ArrayList<>(); List data = new ArrayList<>(); int batchSize = 2; int batchBytes = 1000; for (int i = 0; i < batchSize * 10; i++) { String str = String.valueOf(i); - outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP)); + outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP, getRecordId(str))); data.add(str); } try (PubsubTestClientFactory factory = - PubsubTestClient.createFactoryForPublish(TOPIC, outgoing)) { + PubsubTestClient.createFactoryForPublish(TOPIC, outgoing, + ImmutableList.of())) { PubsubUnboundedSink sink = new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL, - 10, batchSize, batchBytes, Duration.standardSeconds(2)); + NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2), + RecordIdMethod.DETERMINISTIC); TestPipeline p = TestPipeline.create(); p.apply(Create.of(data)) .apply(ParDo.of(new Stamp())) @@ -117,7 +127,7 @@ public void sendMoreThanOneBatchByNumMessages() throws IOException { @Test public void sendMoreThanOneBatchByByteSize() throws IOException { - Set outgoing = new HashSet<>(); + List outgoing = new ArrayList<>(); List data = new ArrayList<>(); int batchSize = 100; int batchBytes = 10; @@ -128,15 +138,17 @@ public void sendMoreThanOneBatchByByteSize() throws IOException { sb.append(String.valueOf(n)); } String str = sb.toString(); - outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP)); + outgoing.add(new OutgoingMessage(str.getBytes(), TIMESTAMP, getRecordId(str))); data.add(str); n += str.length(); } try (PubsubTestClientFactory factory = - PubsubTestClient.createFactoryForPublish(TOPIC, outgoing)) { + PubsubTestClient.createFactoryForPublish(TOPIC, outgoing, + 
ImmutableList.of())) { PubsubUnboundedSink sink = new PubsubUnboundedSink<>(factory, TOPIC, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL, - 10, batchSize, batchBytes, Duration.standardSeconds(2)); + NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2), + RecordIdMethod.DETERMINISTIC); TestPipeline p = TestPipeline.create(); p.apply(Create.of(data)) .apply(ParDo.of(new Stamp())) @@ -146,4 +158,8 @@ public void sendMoreThanOneBatchByByteSize() throws IOException { // The PubsubTestClientFactory will assert fail on close if the actual published // message does not match the expected publish message. } + + // TODO: We would like to test that failed Pubsub publish calls cause the already assigned + // (and random) record ids to be reused. However that can't be done without the test runnner + // supporting retrying bundles. } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSourceTest.java index b265d18dee70..3b0a1c8c00f2 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/PubsubUnboundedSourceTest.java @@ -86,14 +86,14 @@ public long currentTimeMillis() { }; factory = PubsubTestClient.createFactoryForPull(clock, SUBSCRIPTION, ACK_TIMEOUT_S, incoming); PubsubUnboundedSource source = - new PubsubUnboundedSource<>(clock, factory, SUBSCRIPTION, StringUtf8Coder.of(), + new PubsubUnboundedSource<>(clock, factory, null, null, SUBSCRIPTION, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL); primSource = new PubsubSource<>(source); } private void setupOneMessage() { setupOneMessage(ImmutableList.of( - new IncomingMessage(DATA.getBytes(), TIMESTAMP, 0, ACK_ID, RECORD_ID.getBytes()))); + new IncomingMessage(DATA.getBytes(), TIMESTAMP, 0, ACK_ID, RECORD_ID))); } @After @@ -211,7 +211,7 @@ public void multipleReaders() throws IOException { for (int i = 0; i < 2; i++) { String data = String.format("data_%d", i); String ackid = String.format("ackid_%d", i); - incoming.add(new IncomingMessage(data.getBytes(), TIMESTAMP, 0, ackid, RECORD_ID.getBytes())); + incoming.add(new IncomingMessage(data.getBytes(), TIMESTAMP, 0, ackid, RECORD_ID)); } setupOneMessage(incoming); TestPipeline p = TestPipeline.create(); @@ -272,7 +272,7 @@ public void readManyMessages() throws IOException { String recid = String.format("recordid_%d", messageNum); String ackId = String.format("ackid_%d", messageNum); incoming.add(new IncomingMessage(data.getBytes(), messageNumToTimestamp(messageNum), 0, - ackId, recid.getBytes())); + ackId, recid)); } setupOneMessage(incoming); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubApiaryClientTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubApiaryClientTest.java index 40c31fb5ac03..0f3a7bb506dc 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubApiaryClientTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubApiaryClientTest.java @@ -34,7 +34,6 @@ import com.google.api.services.pubsub.model.ReceivedMessage; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.hash.Hashing; import org.junit.After; import org.junit.Before; @@ -61,8 +60,7 @@ public class PubsubApiaryClientTest { private static final String ID_LABEL = "id"; private static final String MESSAGE_ID = "testMessageId"; private static final String DATA = 
"testData"; - private static final String CUSTOM_ID = - Hashing.murmur3_128().hashBytes(DATA.getBytes()).toString(); + private static final String RECORD_ID = "testRecordId"; private static final String ACK_ID = "testAckId"; @Before @@ -89,7 +87,7 @@ public void pullOneMessage() throws IOException { .setPublishTime(String.valueOf(PUB_TIME)) .setAttributes( ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME), - ID_LABEL, CUSTOM_ID)); + ID_LABEL, RECORD_ID)); ReceivedMessage expectedReceivedMessage = new ReceivedMessage().setMessage(expectedPubsubMessage) .setAckId(ACK_ID); @@ -105,7 +103,7 @@ public void pullOneMessage() throws IOException { IncomingMessage actualMessage = acutalMessages.get(0); assertEquals(ACK_ID, actualMessage.ackId); assertEquals(DATA, new String(actualMessage.elementBytes)); - assertEquals(CUSTOM_ID, new String(actualMessage.recordId)); + assertEquals(RECORD_ID, actualMessage.recordId); assertEquals(REQ_TIME, actualMessage.requestTimeMsSinceEpoch); assertEquals(MESSAGE_TIME, actualMessage.timestampMsSinceEpoch); } @@ -117,7 +115,7 @@ public void publishOneMessage() throws IOException { .encodeData(DATA.getBytes()) .setAttributes( ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME), - ID_LABEL, CUSTOM_ID)); + ID_LABEL, RECORD_ID)); PublishRequest expectedRequest = new PublishRequest() .setMessages(ImmutableList.of(expectedPubsubMessage)); PublishResponse expectedResponse = new PublishResponse() @@ -127,7 +125,7 @@ public void publishOneMessage() throws IOException { .publish(expectedTopic, expectedRequest) .execute()) .thenReturn(expectedResponse); - OutgoingMessage actualMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME); + OutgoingMessage actualMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME, RECORD_ID); int n = client.publish(TOPIC, ImmutableList.of(actualMessage)); assertEquals(1, n); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubGrpcClientTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubGrpcClientTest.java index 189049c07ea4..71ee27c86aae 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubGrpcClientTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubGrpcClientTest.java @@ -28,7 +28,6 @@ import com.google.auth.oauth2.GoogleCredentials; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; import com.google.protobuf.Timestamp; import com.google.pubsub.v1.PublishRequest; @@ -70,8 +69,7 @@ public class PubsubGrpcClientTest { private static final String ID_LABEL = "id"; private static final String MESSAGE_ID = "testMessageId"; private static final String DATA = "testData"; - private static final String CUSTOM_ID = - Hashing.murmur3_128().hashBytes(DATA.getBytes()).toString(); + private static final String RECORD_ID = "testRecordId"; private static final String ACK_ID = "testAckId"; @Before @@ -118,7 +116,7 @@ public void pullOneMessage() throws IOException { .putAllAttributes( ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME), - ID_LABEL, CUSTOM_ID)) + ID_LABEL, RECORD_ID)) .build(); ReceivedMessage expectedReceivedMessage = ReceivedMessage.newBuilder() @@ -136,7 +134,7 @@ public void pullOneMessage() throws IOException { IncomingMessage actualMessage = acutalMessages.get(0); assertEquals(ACK_ID, actualMessage.ackId); assertEquals(DATA, new String(actualMessage.elementBytes)); - assertEquals(CUSTOM_ID, 
new String(actualMessage.recordId)); + assertEquals(RECORD_ID, actualMessage.recordId); assertEquals(REQ_TIME, actualMessage.requestTimeMsSinceEpoch); assertEquals(MESSAGE_TIME, actualMessage.timestampMsSinceEpoch); } @@ -149,7 +147,7 @@ public void publishOneMessage() throws IOException { .setData(ByteString.copyFrom(DATA.getBytes())) .putAllAttributes( ImmutableMap.of(TIMESTAMP_LABEL, String.valueOf(MESSAGE_TIME), - ID_LABEL, CUSTOM_ID)) + ID_LABEL, RECORD_ID)) .build(); PublishRequest expectedRequest = PublishRequest.newBuilder() @@ -163,7 +161,7 @@ public void publishOneMessage() throws IOException { .build(); Mockito.when(mockPublisherStub.publish(expectedRequest)) .thenReturn(expectedResponse); - OutgoingMessage actualMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME); + OutgoingMessage actualMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME, RECORD_ID); int n = client.publish(TOPIC, ImmutableList.of(actualMessage)); assertEquals(1, n); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubTestClientTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubTestClientTest.java index fedc8bf57ac8..d788f1070cec 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubTestClientTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PubsubTestClientTest.java @@ -61,7 +61,7 @@ public long currentTimeMillis() { } }; IncomingMessage expectedIncomingMessage = - new IncomingMessage(DATA.getBytes(), MESSAGE_TIME, REQ_TIME, ACK_ID, MESSAGE_ID.getBytes()); + new IncomingMessage(DATA.getBytes(), MESSAGE_TIME, REQ_TIME, ACK_ID, MESSAGE_ID); try (PubsubTestClientFactory factory = PubsubTestClient.createFactoryForPull(clock, SUBSCRIPTION, ACK_TIMEOUT_S, Lists.newArrayList(expectedIncomingMessage))) { @@ -99,9 +99,13 @@ public long currentTimeMillis() { @Test public void publishOneMessage() throws IOException { - OutgoingMessage expectedOutgoingMessage = new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME); - try (PubsubTestClientFactory factory = PubsubTestClient.createFactoryForPublish(TOPIC, Sets - .newHashSet(expectedOutgoingMessage))) { + OutgoingMessage expectedOutgoingMessage = + new OutgoingMessage(DATA.getBytes(), MESSAGE_TIME, MESSAGE_ID); + try (PubsubTestClientFactory factory = + PubsubTestClient.createFactoryForPublish( + TOPIC, + Sets.newHashSet(expectedOutgoingMessage), + ImmutableList.of())) { try (PubsubTestClient client = (PubsubTestClient) factory.newClient(null, null, null)) { client.publish(TOPIC, ImmutableList.of(expectedOutgoingMessage)); } From 6c10a1d73981ea6cac8b557f492c3682e9035d4d Mon Sep 17 00:00:00 2001 From: Dan Halperin Date: Thu, 19 May 2016 17:58:49 -0700 Subject: [PATCH 06/21] DisplayData: a few build fixups --- .../src/main/java/org/apache/beam/sdk/transforms/Combine.java | 3 ++- .../beam/sdk/transforms/display/DisplayDataEvaluatorTest.java | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index 190c413ca0a2..20c12425c9b4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -1785,7 +1785,8 @@ public List> getSideInputs() { public PCollection> apply(PCollection> input) { return input .apply(GroupByKey.create(fewKeys)) - .apply(Combine.groupedValues(fn, fnDisplayData).withSideInputs(sideInputs)); + 
.apply(Combine.groupedValues(fn, fnDisplayData) + .withSideInputs(sideInputs)); } @Override diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataEvaluatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataEvaluatorTest.java index 7b1dc79e6bc9..318c11642b91 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataEvaluatorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataEvaluatorTest.java @@ -30,6 +30,8 @@ import org.apache.beam.sdk.values.POutput; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; import java.io.Serializable; import java.util.Set; @@ -37,6 +39,7 @@ /** * Unit tests for {@link DisplayDataEvaluator}. */ +@RunWith(JUnit4.class) public class DisplayDataEvaluatorTest implements Serializable { @Test From dc98211ccf17e94afb03ba51992c731684f855fa Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 18 May 2016 13:37:13 -0700 Subject: [PATCH 07/21] Add CrashingRunner for use in TestPipeline CrashingRunner is a PipelineRunner that crashes on calls to run() with an IllegalArgumentException. As a runner is currently required to construct a Pipeline object, this allows removal of all Pipeline Runners from the core SDK while retaining tests that depend only on the graph construction behavior. --- .../beam/sdk/testing/CrashingRunner.java | 72 ++++++++++++++++++ .../apache/beam/sdk/testing/TestPipeline.java | 10 ++- .../beam/sdk/testing/CrashingRunnerTest.java | 76 +++++++++++++++++++ .../beam/sdk/testing/TestPipelineTest.java | 17 ++++- 4 files changed, 172 insertions(+), 3 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CrashingRunner.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CrashingRunnerTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CrashingRunner.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CrashingRunner.java new file mode 100644 index 000000000000..975faccc3fa0 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CrashingRunner.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.testing; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AggregatorRetrievalException; +import org.apache.beam.sdk.runners.AggregatorValues; +import org.apache.beam.sdk.runners.PipelineRunner; +import org.apache.beam.sdk.transforms.Aggregator; + +/** + * A {@link PipelineRunner} that applies no overrides and throws an exception on calls to + * {@link Pipeline#run()}. For use in {@link TestPipeline} to construct but not execute pipelines. + */ +public class CrashingRunner extends PipelineRunner{ + + public static CrashingRunner fromOptions(PipelineOptions opts) { + return new CrashingRunner(); + } + + @Override + public PipelineResult run(Pipeline pipeline) { + throw new IllegalArgumentException(String.format("Cannot call #run(Pipeline) on an instance " + + "of %s. %s should only be used as the default to construct a Pipeline " + + "using %s, and cannot execute Pipelines. Instead, specify a %s " + + "by providing PipelineOptions in the environment variable '%s'.", + getClass().getSimpleName(), + getClass().getSimpleName(), + TestPipeline.class.getSimpleName(), + PipelineRunner.class.getSimpleName(), + TestPipeline.PROPERTY_BEAM_TEST_PIPELINE_OPTIONS)); + } + + private static class TestPipelineResult implements PipelineResult { + private TestPipelineResult() { + // Should never be instantiated by the enclosing class + throw new UnsupportedOperationException(String.format("Forbidden to instantiate %s", + getClass().getSimpleName())); + } + + @Override + public State getState() { + throw new UnsupportedOperationException(String.format("Forbidden to instantiate %s", + getClass().getSimpleName())); + } + + @Override + public AggregatorValues getAggregatorValues(Aggregator aggregator) + throws AggregatorRetrievalException { + throw new AssertionError(String.format("Forbidden to instantiate %s", + getClass().getSimpleName())); + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java index a4921d56be0a..4618e33a7204 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java @@ -84,7 +84,8 @@ * containing the message from the {@link PAssert} that failed. */ public class TestPipeline extends Pipeline { - private static final String PROPERTY_BEAM_TEST_PIPELINE_OPTIONS = "beamTestPipelineOptions"; + static final String PROPERTY_BEAM_TEST_PIPELINE_OPTIONS = "beamTestPipelineOptions"; + static final String PROPERTY_USE_DEFAULT_DUMMY_RUNNER = "beamUseDummyRunner"; private static final ObjectMapper MAPPER = new ObjectMapper(); /** @@ -145,8 +146,13 @@ public static PipelineOptions testingPipelineOptions() { .as(TestPipelineOptions.class); options.as(ApplicationNameOptions.class).setAppName(getAppName()); - // If no options were specified, use a test credential object on all pipelines. + // If no options were specified, set some reasonable defaults if (Strings.isNullOrEmpty(beamTestPipelineOptions)) { + // If there are no provided options, check to see if a dummy runner should be used. 
+ String useDefaultDummy = System.getProperty(PROPERTY_USE_DEFAULT_DUMMY_RUNNER); + if (!Strings.isNullOrEmpty(useDefaultDummy) && Boolean.valueOf(useDefaultDummy)) { + options.setRunner(CrashingRunner.class); + } options.as(GcpOptions.class).setGcpCredential(new TestCredential()); } options.setStableUniqueNames(CheckEnabled.ERROR); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CrashingRunnerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CrashingRunnerTest.java new file mode 100644 index 000000000000..041a73ae2d26 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CrashingRunnerTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.testing; + +import static org.junit.Assert.assertTrue; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.runners.PipelineRunner; +import org.apache.beam.sdk.transforms.Create; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link CrashingRunner}. 
+ */ +@RunWith(JUnit4.class) +public class CrashingRunnerTest { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void fromOptionsCreatesInstance() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(CrashingRunner.class); + PipelineRunner runner = PipelineRunner.fromOptions(opts); + + assertTrue("Should have created a CrashingRunner", runner instanceof CrashingRunner); + } + + @Test + public void applySucceeds() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(CrashingRunner.class); + + Pipeline p = Pipeline.create(opts); + p.apply(Create.of(1, 2, 3)); + } + + @Test + public void runThrows() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(CrashingRunner.class); + + Pipeline p = Pipeline.create(opts); + p.apply(Create.of(1, 2, 3)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Cannot call #run"); + thrown.expectMessage(TestPipeline.PROPERTY_BEAM_TEST_PIPELINE_OPTIONS); + + p.run(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java index 8af4ff25bb06..b741e2ed0e2c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java @@ -17,8 +17,8 @@ */ package org.apache.beam.sdk.testing; -import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; @@ -29,6 +29,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.DirectPipelineRunner; +import org.apache.beam.sdk.transforms.Create; import com.fasterxml.jackson.databind.ObjectMapper; @@ -36,6 +37,7 @@ import org.hamcrest.Description; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,6 +51,7 @@ @RunWith(JUnit4.class) public class TestPipelineTest { @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties(); + @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void testCreationUsingDefaults() { @@ -139,6 +142,18 @@ public void testMatcherSerializationDeserialization() { assertEquals(m2, newOpts.getOnSuccessMatcher()); } + @Test + public void testRunWithDummyEnvironmentVariableFails() { + System.getProperties() + .setProperty(TestPipeline.PROPERTY_USE_DEFAULT_DUMMY_RUNNER, Boolean.toString(true)); + TestPipeline pipeline = TestPipeline.create(); + pipeline.apply(Create.of(1, 2, 3)); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Cannot call #run"); + pipeline.run(); + } + /** * TestMatcher is a matcher designed for testing matcher serialization/deserialization. */ From 58d66a344985eecc9cc3f43c0ecd5dbc7b4fb2e6 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 2 May 2016 13:11:12 -0700 Subject: [PATCH 08/21] Add TestFlinkPipelineRunner to FlinkRunnerRegistrar This makes the runner available for selection by integration tests. 
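With the registrar entry in place, an integration test can reach this runner through the beamTestPipelineOptions system property that TestPipeline already reads. A minimal sketch of that selection path, assuming an option value that mirrors the surefire configuration added later in this series; the class name and pipeline contents are illustrative only:

    import org.apache.beam.sdk.testing.TestPipeline;
    import org.apache.beam.sdk.transforms.Create;

    public class SelectTestFlinkRunnerSketch {
      public static void main(String[] args) {
        // Same JSON array of args that the Flink runner's RunnableOnService
        // surefire execution supplies as a system property.
        System.setProperty(
            "beamTestPipelineOptions",
            "[ \"--runner=org.apache.beam.runners.flink.TestFlinkPipelineRunner\", "
                + "\"--streaming=false\" ]");

        // TestPipeline hands these args to PipelineOptionsFactory; registering
        // TestFlinkPipelineRunner in FlinkRunnerRegistrar is what makes it
        // selectable here.
        TestPipeline p = TestPipeline.create();
        p.apply(Create.of(1, 2, 3));
        p.run();
      }
    }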
--- .../runners/flink/FlinkPipelineRunner.java | 16 +---- .../runners/flink/FlinkRunnerRegistrar.java | 4 +- .../flink/TestFlinkPipelineRunner.java | 66 +++++++++++++++++++ .../beam/runners/flink/FlinkTestPipeline.java | 2 +- 4 files changed, 71 insertions(+), 17 deletions(-) create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java index 3edf6f30c22d..b5ffced60d19 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java @@ -108,7 +108,7 @@ public FlinkRunnerResult run(Pipeline pipeline) { this.flinkJobEnv.translate(pipeline); LOG.info("Starting execution of Flink program."); - + JobExecutionResult result; try { result = this.flinkJobEnv.executePipeline(); @@ -138,20 +138,6 @@ public FlinkPipelineOptions getPipelineOptions() { return options; } - /** - * Constructs a runner with default properties for testing. - * - * @return The newly created runner. - */ - public static FlinkPipelineRunner createForTest(boolean streaming) { - FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class); - // we use [auto] for testing since this will make it pick up the Testing - // ExecutionEnvironment - options.setFlinkMaster("[auto]"); - options.setStreaming(streaming); - return new FlinkPipelineRunner(options); - } - @Override public Output apply( PTransform transform, Input input) { diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerRegistrar.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerRegistrar.java index cd99f4e65bce..ec61805a4ed0 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerRegistrar.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerRegistrar.java @@ -41,7 +41,9 @@ private FlinkRunnerRegistrar() { } public static class Runner implements PipelineRunnerRegistrar { @Override public Iterable>> getPipelineRunners() { - return ImmutableList.>>of(FlinkPipelineRunner.class); + return ImmutableList.>>of( + FlinkPipelineRunner.class, + TestFlinkPipelineRunner.class); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java new file mode 100644 index 000000000000..24883c8035c2 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.PipelineOptionsValidator; +import org.apache.beam.sdk.runners.PipelineRunner; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; + +public class TestFlinkPipelineRunner extends PipelineRunner { + + private FlinkPipelineRunner delegate; + + private TestFlinkPipelineRunner(FlinkPipelineOptions options) { + // We use [auto] for testing since this will make it pick up the Testing ExecutionEnvironment + options.setFlinkMaster("[auto]"); + this.delegate = FlinkPipelineRunner.fromOptions(options); + } + + public static TestFlinkPipelineRunner fromOptions(PipelineOptions options) { + FlinkPipelineOptions flinkOptions = PipelineOptionsValidator.validate(FlinkPipelineOptions.class, options); + return new TestFlinkPipelineRunner(flinkOptions); + } + + public static TestFlinkPipelineRunner create(boolean streaming) { + FlinkPipelineOptions flinkOptions = PipelineOptionsFactory.as(FlinkPipelineOptions.class); + flinkOptions.setStreaming(streaming); + return TestFlinkPipelineRunner.fromOptions(flinkOptions); + } + + @Override + public + OutputT apply(PTransform transform, InputT input) { + return delegate.apply(transform, input); + } + + @Override + public FlinkRunnerResult run(Pipeline pipeline) { + return delegate.run(pipeline); + } + + public PipelineOptions getPipelineOptions() { + return delegate.getPipelineOptions(); + } +} + + diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java index f015a6680568..edde925c330c 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java @@ -60,7 +60,7 @@ public static FlinkTestPipeline createForStreaming() { * @return The Test Pipeline. */ private static FlinkTestPipeline create(boolean streaming) { - FlinkPipelineRunner flinkRunner = FlinkPipelineRunner.createForTest(streaming); + TestFlinkPipelineRunner flinkRunner = TestFlinkPipelineRunner.create(streaming); return new FlinkTestPipeline(flinkRunner, flinkRunner.getPipelineOptions()); } From bfc1a2ba041c1b8b0033f886266321e5ee53cf6c Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 2 May 2016 14:04:20 -0700 Subject: [PATCH 09/21] Configure RunnableOnService tests for Flink in batch mode Today Flink batch supports only global windows. This is a situation we intend our build to allow, eventually via JUnit category filtering. For now all the test classes that use non-global windows are excluded entirely via maven configuration. In the future, it should be on a per-test-method basis. 
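The surefire groups filter configured below keys on the RunnableOnService JUnit category that individual SDK tests carry. A hedged sketch of such a test, for context while reading the exclusions; the class name and assertions are illustrative, not taken from this series:

    import org.apache.beam.sdk.testing.PAssert;
    import org.apache.beam.sdk.testing.RunnableOnService;
    import org.apache.beam.sdk.testing.TestPipeline;
    import org.apache.beam.sdk.transforms.Create;

    import org.junit.Test;
    import org.junit.experimental.categories.Category;

    public class ExampleRunnableOnServiceTest {
      @Test
      @Category(RunnableOnService.class) // matched by the surefire groups setting
      public void runsOnTheConfiguredRunner() {
        TestPipeline p = TestPipeline.create();
        PAssert.that(p.apply(Create.of(1, 2, 3))).containsInAnyOrder(1, 2, 3);
        p.run();
      }
    }

Excluding whole test classes in the pom, as done below, is the coarse-grained stand-in for filtering on this category per test method.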
--- runners/flink/runner/pom.xml | 106 ++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 27 deletions(-) diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index a53a386c2828..cde910873285 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -34,31 +34,6 @@ jar - - - disable-runnable-on-service-tests - - true - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - runnable-on-service-tests - - true - - - - - - - - - @@ -87,7 +62,8 @@ flink-avro_2.10 ${flink.version} - + + org.apache.beam java-sdk-all @@ -111,6 +87,21 @@ + + + + org.apache.beam + java-sdk-all + tests + test + + + org.slf4j + slf4j-jdk14 + + + + org.apache.beam java-examples-all @@ -168,10 +159,71 @@ org.apache.maven.plugins maven-surefire-plugin + + + runnable-on-service-tests + integration-test + + test + + + org.apache.beam.sdk.testing.RunnableOnService + all + 4 + true + + org.apache.beam:java-sdk-all + + + + [ + "--runner=org.apache.beam.runners.flink.TestFlinkPipelineRunner", + "--streaming=false" + ] + + + + + **/org/apache/beam/sdk/transforms/CombineTest.java + **/org/apache/beam/sdk/transforms/GroupByKeyTest.java + **/org/apache/beam/sdk/transforms/ViewTest.java + **/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java + **/org/apache/beam/sdk/transforms/windowing/WindowTest.java + **/org/apache/beam/sdk/transforms/windowing/WindowingTest.java + **/org/apache/beam/sdk/util/ReshuffleTest.java + + + + + streaming-runnable-on-service-tests + integration-test + + test + + + org.apache.beam.sdk.testing.RunnableOnService + all + 4 + true + + org.apache.beam:java-sdk-all + + + + [ + "--runner=org.apache.beam.runners.flink.TestFlinkPipelineRunner", + "--streaming=true" + ] + + + + + + + - From 55f39bf7cbe65980ed4233146517d608977ddaf6 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 6 May 2016 10:54:41 -0700 Subject: [PATCH 10/21] Remove unused threadCount from integration tests --- runners/flink/runner/pom.xml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index cde910873285..f94ce6894203 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -168,8 +168,7 @@ org.apache.beam.sdk.testing.RunnableOnService - all - 4 + none true org.apache.beam:java-sdk-all @@ -202,8 +201,7 @@ org.apache.beam.sdk.testing.RunnableOnService - all - 4 + none true org.apache.beam:java-sdk-all From 2d71af71c26992c0ccbf2ab6df8f2a0aef5e586b Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 6 May 2016 10:55:16 -0700 Subject: [PATCH 11/21] Disable Flink streaming integration tests for now --- runners/flink/runner/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index f94ce6894203..7e8c5a98a79d 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -200,6 +200,7 @@ test + true org.apache.beam.sdk.testing.RunnableOnService none true From af8e98878bbc8678e33a4c00548ccabf6cf55a17 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 6 May 2016 12:49:55 -0700 Subject: [PATCH 12/21] Special casing job exec AssertionError in TestFlinkPipelineRunner --- .../runners/flink/TestFlinkPipelineRunner.java | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java 
b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java index 24883c8035c2..139aebf9dd2b 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkPipelineRunner.java @@ -26,6 +26,8 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; +import org.apache.flink.runtime.client.JobExecutionException; + public class TestFlinkPipelineRunner extends PipelineRunner { private FlinkPipelineRunner delegate; @@ -55,7 +57,19 @@ OutputT apply(PTransform transform, InputT input) { @Override public FlinkRunnerResult run(Pipeline pipeline) { - return delegate.run(pipeline); + try { + return delegate.run(pipeline); + } catch (RuntimeException e) { + // Special case hack to pull out assertion errors from PAssert; instead there should + // probably be a better story along the lines of UserCodeException. + if (e.getCause() != null + && e.getCause() instanceof JobExecutionException + && e.getCause().getCause() instanceof AssertionError) { + throw (AssertionError) e.getCause().getCause(); + } else { + throw e; + } + } } public PipelineOptions getPipelineOptions() { From 1664c96db5951d74b5ab9a5850def1dbef8adea6 Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Fri, 6 May 2016 09:38:55 +0200 Subject: [PATCH 13/21] Add hamcrest dependency to Flink Runner Without it the RunnableOnService tests seem to not work --- runners/flink/runner/pom.xml | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index 7e8c5a98a79d..fda27a863581 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -88,6 +88,22 @@ + + org.hamcrest + hamcrest-all + test + + + junit + junit + test + + + org.mockito + mockito-all + test + + org.apache.beam @@ -124,12 +140,6 @@ org.apache.flink flink-test-utils_2.10 ${flink.version} - test - - - org.mockito - mockito-all - test From 26fa0b21cfda3049e26d47ce174a9b29fe3ec29c Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Fri, 6 May 2016 08:26:50 +0200 Subject: [PATCH 14/21] Fix Dangling Flink DataSets --- .../FlinkBatchPipelineTranslator.java | 14 ++++++++++++++ .../FlinkBatchTranslationContext.java | 18 +++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java index 3d39e8182cab..512b8229c9ce 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.TransformTreeNode; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -24,7 +25,9 @@ import org.apache.beam.sdk.transforms.join.CoGroupByKey; import org.apache.beam.sdk.values.PValue; +import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.io.DiscardingOutputFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,6 
+50,17 @@ public FlinkBatchPipelineTranslator(ExecutionEnvironment env, PipelineOptions op this.batchContext = new FlinkBatchTranslationContext(env, options); } + @Override + @SuppressWarnings("rawtypes, unchecked") + public void translate(Pipeline pipeline) { + super.translate(pipeline); + + // terminate dangling DataSets + for (DataSet dataSet: batchContext.getDanglingDataSets().values()) { + dataSet.output(new DiscardingOutputFormat()); + } + } + // -------------------------------------------------------------------------------------------- // Pipeline Visitor Methods // -------------------------------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java index 71950cf216cb..501b1ea5555c 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java @@ -43,6 +43,13 @@ public class FlinkBatchTranslationContext { private final Map> dataSets; private final Map, DataSet> broadcastDataSets; + /** + * For keeping track about which DataSets don't have a successor. We + * need to terminate these with a discarding sink because the Beam + * model allows dangling operations. + */ + private final Map> danglingDataSets; + private final ExecutionEnvironment env; private final PipelineOptions options; @@ -55,10 +62,16 @@ public FlinkBatchTranslationContext(ExecutionEnvironment env, PipelineOptions op this.options = options; this.dataSets = new HashMap<>(); this.broadcastDataSets = new HashMap<>(); + + this.danglingDataSets = new HashMap<>(); } // ------------------------------------------------------------------------ - + + public Map> getDanglingDataSets() { + return danglingDataSets; + } + public ExecutionEnvironment getExecutionEnvironment() { return env; } @@ -69,12 +82,15 @@ public PipelineOptions getPipelineOptions() { @SuppressWarnings("unchecked") public DataSet getInputDataSet(PValue value) { + // assume that the DataSet is used as an input if retrieved here + danglingDataSets.remove(value); return (DataSet) dataSets.get(value); } public void setOutputDataSet(PValue value, DataSet set) { if (!dataSets.containsKey(value)) { dataSets.put(value, set); + danglingDataSets.put(value, set); } } From 4e60a497b313414aa2b2968b8def6c6f753908fe Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Fri, 13 May 2016 14:17:50 +0200 Subject: [PATCH 15/21] Fix faulty Flink Flatten when PCollectionList is empty --- .../FlinkBatchTransformTranslators.java | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java index a03352efae15..07785aa47c69 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import 
org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.AvroIO; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Read; @@ -61,6 +62,7 @@ import com.google.api.client.util.Maps; import com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -78,6 +80,7 @@ import org.apache.flink.api.java.operators.MapPartitionOperator; import org.apache.flink.api.java.operators.UnsortedGrouping; import org.apache.flink.core.fs.Path; +import org.apache.flink.util.Collector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -91,7 +94,7 @@ /** * Translators for transforming * Dataflow {@link org.apache.beam.sdk.transforms.PTransform}s to - * Flink {@link org.apache.flink.api.java.DataSet}s + * Flink {@link org.apache.flink.api.java.DataSet}s. */ public class FlinkBatchTransformTranslators { @@ -465,15 +468,30 @@ public void translateNode(ParDo.BoundMulti transform, FlinkBatchTransla private static class FlattenPCollectionTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { @Override + @SuppressWarnings("unchecked") public void translateNode(Flatten.FlattenPCollectionList transform, FlinkBatchTranslationContext context) { List> allInputs = context.getInput(transform).getAll(); DataSet result = null; - for(PCollection collection : allInputs) { - DataSet current = context.getInputDataSet(collection); - if (result == null) { - result = current; - } else { - result = result.union(current); + if (allInputs.isEmpty()) { + // create an empty dummy source to satisfy downstream operations + // we cannot create an empty source in Flink, therefore we have to + // add the flatMap that simply never forwards the single element + DataSource dummySource = + context.getExecutionEnvironment().fromElements("dummy"); + result = dummySource.flatMap(new FlatMapFunction() { + @Override + public void flatMap(String s, Collector collector) throws Exception { + // never return anything + } + }).returns(new CoderTypeInformation<>((Coder) VoidCoder.of())); + } else { + for (PCollection collection : allInputs) { + DataSet current = context.getInputDataSet(collection); + if (result == null) { + result = current; + } else { + result = result.union(current); + } } } context.setOutputDataSet(context.getOutput(transform), result); From 24bfca230d5db3cb75dd0e30093a10f7523c1238 Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Tue, 10 May 2016 13:53:03 +0200 Subject: [PATCH 16/21] [BEAM-270] Support Timestamps/Windows in Flink Batch With this change we always use WindowedValue for the underlying Flink DataSets instead of just T. This allows us to support windowing as well. This changes also a lot of other stuff enabled by the above: - Use WindowedValue throughout - Add proper translation for Window.into() - Make side inputs window aware - Make GroupByKey and Combine transformations window aware, this includes support for merging windows. 
GroupByKey is implemented as a Combine with a concatenating CombineFn, for simplicity This removes Flink specific transformations for things that are handled by builtin sources/sinks, among other things this: - Removes special translation for AvroIO.Read/Write and TextIO.Read/Write - Removes special support for Write.Bound, this was not working properly and is now handled by the Beam machinery that uses DoFns for this - Removes special translation for binary Co-Group, the code was still in there but was never used - Removes ConsoleIO, this can be done using a DoFn With this change all RunnableOnService tests run on Flink Batch. --- runners/flink/runner/pom.xml | 10 - .../beam/runners/flink/io/ConsoleIO.java | 82 -- .../FlinkBatchPipelineTranslator.java | 4 +- .../FlinkBatchTransformTranslators.java | 846 ++++++++++++------ .../FlinkBatchTranslationContext.java | 56 +- .../FlinkStreamingTransformTranslators.java | 22 +- .../FlinkStreamingTranslationContext.java | 29 +- .../functions/FlinkAssignContext.java | 56 ++ .../functions/FlinkAssignWindows.java | 51 ++ .../FlinkCoGroupKeyedListAggregator.java | 61 -- .../functions/FlinkCreateFunction.java | 63 -- .../functions/FlinkDoFnFunction.java | 194 ++-- .../FlinkKeyedListAggregationFunction.java | 78 -- .../FlinkMergingNonShuffleReduceFunction.java | 238 +++++ .../FlinkMergingPartialReduceFunction.java | 205 +++++ .../functions/FlinkMergingReduceFunction.java | 207 +++++ .../FlinkMultiOutputDoFnFunction.java | 157 ++-- .../FlinkMultiOutputProcessContext.java | 176 ++++ .../FlinkMultiOutputPruningFunction.java | 25 +- .../FlinkNoElementAssignContext.java | 71 ++ .../functions/FlinkPartialReduceFunction.java | 171 +++- .../functions/FlinkProcessContext.java | 324 +++++++ .../functions/FlinkReduceFunction.java | 174 +++- .../functions/SideInputInitializer.java | 75 ++ .../translation/functions/UnionCoder.java | 152 ---- .../types/CoderTypeInformation.java | 21 +- .../types/CoderTypeSerializer.java | 14 +- .../translation/types/KvCoderComperator.java | 102 +-- .../types/KvCoderTypeInformation.java | 63 +- .../types/VoidCoderTypeSerializer.java | 112 --- .../wrappers/CombineFnAggregatorWrapper.java | 94 -- .../SerializableFnAggregatorWrapper.java | 31 +- .../wrappers/SinkOutputFormat.java | 10 +- .../wrappers/SourceInputFormat.java | 18 +- .../streaming/FlinkGroupByKeyWrapper.java | 10 +- .../io/FlinkStreamingCreateFunction.java | 9 +- .../apache/beam/runners/flink/AvroITCase.java | 129 --- .../beam/runners/flink/FlattenizeITCase.java | 76 -- .../runners/flink/JoinExamplesITCase.java | 102 --- .../runners/flink/MaybeEmptyTestITCase.java | 66 -- .../runners/flink/ParDoMultiOutputITCase.java | 102 --- .../beam/runners/flink/ReadSourceITCase.java | 14 +- .../flink/RemoveDuplicatesEmptyITCase.java | 72 -- .../runners/flink/RemoveDuplicatesITCase.java | 73 -- .../beam/runners/flink/SideInputITCase.java | 70 -- .../beam/runners/flink/TfIdfITCase.java | 80 -- .../beam/runners/flink/WordCountITCase.java | 77 -- .../runners/flink/WordCountJoin2ITCase.java | 140 --- .../runners/flink/WordCountJoin3ITCase.java | 158 ---- .../streaming/GroupAlsoByWindowTest.java | 3 +- .../beam/runners/flink/util/JoinExamples.java | 161 ---- .../beam/sdk/transforms/join/UnionCoder.java | 2 +- 52 files changed, 2605 insertions(+), 2731 deletions(-) delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignContext.java 
create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignWindows.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingPartialReduceFunction.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingReduceFunction.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputProcessContext.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkNoElementAssignContext.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkProcessContext.java create mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/SideInputInitializer.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java delete mode 100644 runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java delete mode 100644 runners/flink/runner/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index fda27a863581..b29a5bf221c0 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -191,16 +191,6 @@ ] - - - 
**/org/apache/beam/sdk/transforms/CombineTest.java - **/org/apache/beam/sdk/transforms/GroupByKeyTest.java - **/org/apache/beam/sdk/transforms/ViewTest.java - **/org/apache/beam/sdk/transforms/join/CoGroupByKeyTest.java - **/org/apache/beam/sdk/transforms/windowing/WindowTest.java - **/org/apache/beam/sdk/transforms/windowing/WindowingTest.java - **/org/apache/beam/sdk/util/ReshuffleTest.java - diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java deleted file mode 100644 index 9c36c217df36..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/io/ConsoleIO.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.io; - -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PDone; - -/** - * Transform for printing the contents of a {@link org.apache.beam.sdk.values.PCollection}. - * to standard output. - * - * This is Flink-specific and will only work when executed using the - * {@link org.apache.beam.runners.flink.FlinkPipelineRunner}. - */ -public class ConsoleIO { - - /** - * A PTransform that writes a PCollection to a standard output. - */ - public static class Write { - - /** - * Returns a ConsoleIO.Write PTransform with a default step name. - */ - public static Bound create() { - return new Bound(); - } - - /** - * Returns a ConsoleIO.Write PTransform with the given step name. - */ - public static Bound named(String name) { - return new Bound().named(name); - } - - /** - * A PTransform that writes a bounded PCollection to standard output. - */ - public static class Bound extends PTransform, PDone> { - private static final long serialVersionUID = 0; - - Bound() { - super("ConsoleIO.Write"); - } - - Bound(String name) { - super(name); - } - - /** - * Returns a new ConsoleIO.Write PTransform that's like this one but with the given - * step - * name. Does not modify this object. 
- */ - public Bound named(String name) { - return new Bound(name); - } - - @Override - public PDone apply(PCollection input) { - return PDone.in(input.getPipeline()); - } - } - } -} - diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java index 512b8229c9ce..69c02a22b36d 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java @@ -32,8 +32,8 @@ import org.slf4j.LoggerFactory; /** - * FlinkBatchPipelineTranslator knows how to translate Pipeline objects into Flink Jobs. - * This is based on {@link org.apache.beam.runners.dataflow.DataflowPipelineTranslator} + * {@link Pipeline.PipelineVisitor} for executing a {@link Pipeline} as a + * Flink batch job. */ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator { diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java index 07785aa47c69..83588076c46c 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -17,23 +17,24 @@ */ package org.apache.beam.runners.flink.translation; -import org.apache.beam.runners.flink.io.ConsoleIO; -import org.apache.beam.runners.flink.translation.functions.FlinkCoGroupKeyedListAggregator; -import org.apache.beam.runners.flink.translation.functions.FlinkCreateFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction; -import org.apache.beam.runners.flink.translation.functions.FlinkKeyedListAggregationFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMergingPartialReduceFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkMergingReduceFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputDoFnFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction; import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction; import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction; -import org.apache.beam.runners.flink.translation.functions.UnionCoder; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.SinkOutputFormat; import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.AvroIO; import 
org.apache.beam.sdk.io.BoundedSource; @@ -41,60 +42,63 @@ import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.transforms.join.CoGbkResultSchema; -import org.apache.beam.sdk.transforms.join.CoGroupByKey; -import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.join.UnionCoder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; -import com.google.api.client.util.Maps; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.io.AvroInputFormat; import org.apache.flink.api.java.io.AvroOutputFormat; -import org.apache.flink.api.java.io.TextInputFormat; -import org.apache.flink.api.java.operators.CoGroupOperator; import org.apache.flink.api.java.operators.DataSink; import org.apache.flink.api.java.operators.DataSource; import org.apache.flink.api.java.operators.FlatMapOperator; import org.apache.flink.api.java.operators.GroupCombineOperator; import org.apache.flink.api.java.operators.GroupReduceOperator; import org.apache.flink.api.java.operators.Grouping; +import org.apache.flink.api.java.operators.MapOperator; import org.apache.flink.api.java.operators.MapPartitionOperator; +import org.apache.flink.api.java.operators.SingleInputUdfOperator; import org.apache.flink.api.java.operators.UnsortedGrouping; import org.apache.flink.core.fs.Path; import org.apache.flink.util.Collector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** - * Translators for transforming - * Dataflow {@link org.apache.beam.sdk.transforms.PTransform}s to - * Flink {@link org.apache.flink.api.java.DataSet}s. + * Translators for transforming {@link PTransform PTransforms} to + * Flink {@link DataSet DataSets}. 
*/ public class FlinkBatchTransformTranslators { @@ -103,113 +107,90 @@ public class FlinkBatchTransformTranslators { // -------------------------------------------------------------------------------------------- @SuppressWarnings("rawtypes") - private static final Map, FlinkBatchPipelineTranslator.BatchTransformTranslator> TRANSLATORS = new HashMap<>(); + private static final Map< + Class, + FlinkBatchPipelineTranslator.BatchTransformTranslator> TRANSLATORS = new HashMap<>(); - // register the known translators static { TRANSLATORS.put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); - // we don't need this because we translate the Combine.PerKey directly - //TRANSLATORS.put(Combine.GroupedValues.class, new CombineGroupedValuesTranslator()); - - TRANSLATORS.put(Create.Values.class, new CreateTranslatorBatch()); + TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslatorBatch()); - // TODO we're currently ignoring windows here but that has to change in the future - TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + TRANSLATORS.put(Window.Bound.class, new WindowBoundTranslatorBatch()); - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); - - TRANSLATORS.put(CoGroupByKey.class, new CoGroupByKeyTranslatorBatch()); - - TRANSLATORS.put(AvroIO.Read.Bound.class, new AvroIOReadTranslatorBatch()); - TRANSLATORS.put(AvroIO.Write.Bound.class, new AvroIOWriteTranslatorBatch()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); - TRANSLATORS.put(Write.Bound.class, new WriteSinkTranslatorBatch()); - - TRANSLATORS.put(TextIO.Read.Bound.class, new TextIOReadTranslatorBatch()); - TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteTranslatorBatch()); - - // Flink-specific - TRANSLATORS.put(ConsoleIO.Write.Bound.class, new ConsoleIOWriteTranslatorBatch()); - } - public static FlinkBatchPipelineTranslator.BatchTransformTranslator getTranslator(PTransform transform) { + public static FlinkBatchPipelineTranslator.BatchTransformTranslator getTranslator( + PTransform transform) { return TRANSLATORS.get(transform.getClass()); } - private static class ReadSourceTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class ReadSourceTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { @Override public void translateNode(Read.Bounded transform, FlinkBatchTranslationContext context) { String name = transform.getName(); BoundedSource source = transform.getSource(); PCollection output = context.getOutput(transform); - Coder coder = output.getCoder(); - TypeInformation typeInformation = context.getTypeInfo(output); + TypeInformation> typeInformation = context.getTypeInfo(output); - DataSource dataSource = new DataSource<>(context.getExecutionEnvironment(), - new SourceInputFormat<>(source, context.getPipelineOptions()), typeInformation, name); + DataSource> dataSource = new DataSource<>( + context.getExecutionEnvironment(), + new SourceInputFormat<>(source, context.getPipelineOptions()), + typeInformation, + name); context.setOutputDataSet(output, dataSource); } } - private static class AvroIOReadTranslatorBatch implements 
FlinkBatchPipelineTranslator.BatchTransformTranslator> { - private static final Logger LOG = LoggerFactory.getLogger(AvroIOReadTranslatorBatch.class); + private static class WriteSinkTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { @Override - public void translateNode(AvroIO.Read.Bound transform, FlinkBatchTranslationContext context) { - String path = transform.getFilepattern(); + public void translateNode(Write.Bound transform, FlinkBatchTranslationContext context) { String name = transform.getName(); -// Schema schema = transform.getSchema(); - PValue output = context.getOutput(transform); - - TypeInformation typeInformation = context.getTypeInfo(output); - - // This is super hacky, but unfortunately we cannot get the type otherwise - Class extractedAvroType; - try { - Field typeField = transform.getClass().getDeclaredField("type"); - typeField.setAccessible(true); - @SuppressWarnings("unchecked") - Class avroType = (Class) typeField.get(transform); - extractedAvroType = avroType; - } catch (NoSuchFieldException | IllegalAccessException e) { - // we know that the field is there and it is accessible - throw new RuntimeException("Could not access type from AvroIO.Bound", e); - } - - DataSource source = new DataSource<>(context.getExecutionEnvironment(), - new AvroInputFormat<>(new Path(path), extractedAvroType), - typeInformation, name); + PValue input = context.getInput(transform); + DataSet> inputDataSet = context.getInputDataSet(input); - context.setOutputDataSet(output, source); + inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())) + .name(name); } } - private static class AvroIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class AvroIOWriteTranslatorBatch implements + FlinkBatchPipelineTranslator.BatchTransformTranslator> { private static final Logger LOG = LoggerFactory.getLogger(AvroIOWriteTranslatorBatch.class); + @Override - public void translateNode(AvroIO.Write.Bound transform, FlinkBatchTranslationContext context) { - DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + public void translateNode( + AvroIO.Write.Bound transform, + FlinkBatchTranslationContext context) { + DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); + String filenamePrefix = transform.getFilenamePrefix(); String filenameSuffix = transform.getFilenameSuffix(); int numShards = transform.getNumShards(); String shardNameTemplate = transform.getShardNameTemplate(); // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", + LOG.warn( + "Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); - LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate); + LOG.warn( + "Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", + shardNameTemplate); // This is super hacky, but unfortunately we cannot get the type otherwise Class extractedAvroType; @@ -224,8 +205,17 @@ public void translateNode(AvroIO.Write.Bound transform, FlinkBatchTranslation throw new RuntimeException("Could not access type from AvroIO.Bound", e); } - DataSink dataSink = inputDataSet.output(new AvroOutputFormat<>(new Path - (filenamePrefix), extractedAvroType)); + MapOperator, T> valueStream = inputDataSet.map( + new MapFunction, T>() { + @Override + public T map(WindowedValue value) throws Exception { + return value.getValue(); + } + }).returns(new CoderTypeInformation<>(context.getInput(transform).getCoder())); + + + DataSink dataSink = valueStream.output( + new AvroOutputFormat<>(new Path(filenamePrefix), extractedAvroType)); if (numShards > 0) { dataSink.setParallelism(numShards); @@ -233,37 +223,16 @@ public void translateNode(AvroIO.Write.Bound transform, FlinkBatchTranslation } } - private static class TextIOReadTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { - private static final Logger LOG = LoggerFactory.getLogger(TextIOReadTranslatorBatch.class); - - @Override - public void translateNode(TextIO.Read.Bound transform, FlinkBatchTranslationContext context) { - String path = transform.getFilepattern(); - String name = transform.getName(); - - TextIO.CompressionType compressionType = transform.getCompressionType(); - boolean needsValidation = transform.needsValidation(); - - // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.CompressionType not yet supported. Is: {}.", compressionType); - LOG.warn("Translation of TextIO.Read.needsValidation not yet supported. Is: {}.", needsValidation); - - PValue output = context.getOutput(transform); - - TypeInformation typeInformation = context.getTypeInfo(output); - DataSource source = new DataSource<>(context.getExecutionEnvironment(), new TextInputFormat(new Path(path)), typeInformation, name); - - context.setOutputDataSet(output, source); - } - } - - private static class TextIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class TextIOWriteTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteTranslatorBatch.class); @Override - public void translateNode(TextIO.Write.Bound transform, FlinkBatchTranslationContext context) { + public void translateNode( + TextIO.Write.Bound transform, + FlinkBatchTranslationContext context) { PValue input = context.getInput(transform); - DataSet inputDataSet = context.getInputDataSet(input); + DataSet> inputDataSet = context.getInputDataSet(input); String filenamePrefix = transform.getFilenamePrefix(); String filenameSuffix = transform.getFilenameSuffix(); @@ -272,12 +241,25 @@ public void translateNode(TextIO.Write.Bound transform, FlinkBatchTranslation String shardNameTemplate = transform.getShardNameTemplate(); // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); - LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); - LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate); + LOG.warn( + "Translation of TextIO.Write.needsValidation not yet supported. 
Is: {}.", + needsValidation); + LOG.warn( + "Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", + filenameSuffix); + LOG.warn( + "Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", + shardNameTemplate); - //inputDataSet.print(); - DataSink dataSink = inputDataSet.writeAsText(filenamePrefix); + MapOperator, T> valueStream = inputDataSet.map( + new MapFunction, T>() { + @Override + public T map(WindowedValue value) throws Exception { + return value.getValue(); + } + }).returns(new CoderTypeInformation<>(transform.getCoder())); + + DataSink dataSink = valueStream.writeAsText(filenamePrefix); if (numShards > 0) { dataSink.setParallelism(numShards); @@ -285,148 +267,414 @@ public void translateNode(TextIO.Write.Bound transform, FlinkBatchTranslation } } - private static class ConsoleIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator { + private static class WindowBoundTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + @Override - public void translateNode(ConsoleIO.Write.Bound transform, FlinkBatchTranslationContext context) { + public void translateNode(Window.Bound transform, FlinkBatchTranslationContext context) { PValue input = context.getInput(transform); - DataSet inputDataSet = context.getInputDataSet(input); - inputDataSet.printOnTaskManager(transform.getName()); + + TypeInformation> resultTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + DataSet> inputDataSet = context.getInputDataSet(input); + + @SuppressWarnings("unchecked") + final WindowingStrategy windowingStrategy = + (WindowingStrategy) + context.getOutput(transform).getWindowingStrategy(); + + WindowFn windowFn = windowingStrategy.getWindowFn(); + + FlinkAssignWindows assignWindowsFunction = + new FlinkAssignWindows<>(windowFn); + + DataSet> resultDataSet = inputDataSet + .flatMap(assignWindowsFunction) + .name(context.getOutput(transform).getName()) + .returns(resultTypeInfo); + + context.setOutputDataSet(context.getOutput(transform), resultDataSet); } } - private static class WriteSinkTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class GroupByKeyTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { @Override - public void translateNode(Write.Bound transform, FlinkBatchTranslationContext context) { - String name = transform.getName(); - PValue input = context.getInput(transform); - DataSet inputDataSet = context.getInputDataSet(input); + public void translateNode( + GroupByKey transform, + FlinkBatchTranslationContext context) { + + // for now, this is copied from the Combine.PerKey translater. 
Once we have the new runner API + // we can replace GroupByKey by a Combine.PerKey with the Concatenate CombineFn + + DataSet>> inputDataSet = + context.getInputDataSet(context.getInput(transform)); + + Combine.KeyedCombineFn, List> combineFn = + new Concatenate().asKeyedFn(); + + KvCoder inputCoder = (KvCoder) context.getInput(transform).getCoder(); + + Coder> accumulatorCoder; + + try { + accumulatorCoder = + combineFn.getAccumulatorCoder( + context.getInput(transform).getPipeline().getCoderRegistry(), + inputCoder.getKeyCoder(), + inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + + WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + + TypeInformation>> kvCoderTypeInformation = + new KvCoderTypeInformation<>( + WindowedValue.getFullCoder( + inputCoder, + windowingStrategy.getWindowFn().windowCoder())); + + TypeInformation>>> partialReduceTypeInfo = + new KvCoderTypeInformation<>( + WindowedValue.getFullCoder( + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + windowingStrategy.getWindowFn().windowCoder())); + + Grouping>> inputGrouping = + new UnsortedGrouping<>( + inputDataSet, + new Keys.ExpressionKeys<>(new String[]{"key"}, + kvCoderTypeInformation)); + + FlinkPartialReduceFunction, ?> partialReduceFunction; + FlinkReduceFunction, List, ?> reduceFunction; + + if (windowingStrategy.getWindowFn().isNonMerging()) { + @SuppressWarnings("unchecked") + WindowingStrategy boundedStrategy = + (WindowingStrategy) windowingStrategy; + + partialReduceFunction = new FlinkPartialReduceFunction<>( + combineFn, + boundedStrategy, + Collections., WindowingStrategy>emptyMap(), + context.getPipelineOptions()); + + reduceFunction = new FlinkReduceFunction<>( + combineFn, + boundedStrategy, + Collections., WindowingStrategy>emptyMap(), + context.getPipelineOptions()); + + } else { + if (!windowingStrategy.getWindowFn().windowCoder().equals(IntervalWindow.getCoder())) { + throw new UnsupportedOperationException( + "Merging WindowFn with windows other than IntervalWindow are not supported."); + } + + @SuppressWarnings("unchecked") + WindowingStrategy intervalStrategy = + (WindowingStrategy) windowingStrategy; + + partialReduceFunction = new FlinkMergingPartialReduceFunction<>( + combineFn, + intervalStrategy, + Collections., WindowingStrategy>emptyMap(), + context.getPipelineOptions()); + + reduceFunction = new FlinkMergingReduceFunction<>( + combineFn, + intervalStrategy, + Collections., WindowingStrategy>emptyMap(), + context.getPipelineOptions()); + } + + // Partially GroupReduce the values into the intermediate format AccumT (combine) + GroupCombineOperator< + WindowedValue>, + WindowedValue>>> groupCombine = + new GroupCombineOperator<>( + inputGrouping, + partialReduceTypeInfo, + partialReduceFunction, + "GroupCombine: " + transform.getName()); + + Grouping>>> intermediateGrouping = + new UnsortedGrouping<>( + groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + + // Fully reduce the values and create output format VO + GroupReduceOperator< + WindowedValue>>, WindowedValue>>> outputDataSet = + new GroupReduceOperator<>( + intermediateGrouping, partialReduceTypeInfo, reduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); - inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())).name(name); } } /** - * Translates a GroupByKey while ignoring window assignments. 
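The GroupByKey translation above funnels everything through the Combine.PerKey machinery: a pre-shuffle GroupCombine builds per-bundle list accumulators with the Concatenate combine function, and the final GroupReduce merges those accumulators per key. A rough, self-contained sketch of that two-phase shape (plain Java, no Beam or Flink types; the method names mirror the CombineFn lifecycle but are not the runner's API):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Minimal sketch of the two-phase combine shape used in the translation above:
// each "partition" is partially combined into an accumulator (the pre-shuffle
// GroupCombine), and the accumulators for one key are then merged (the GroupReduce).
public class TwoPhaseConcatenateSketch {

  static <T> List<T> createAccumulator() { return new ArrayList<>(); }

  static <T> List<T> addInput(List<T> acc, T input) { acc.add(input); return acc; }

  static <T> List<T> mergeAccumulators(List<List<T>> accs) {
    List<T> result = createAccumulator();
    for (List<T> acc : accs) { result.addAll(acc); }
    return result;
  }

  public static void main(String[] args) {
    // Two partitions of values that share the same key.
    List<String> partitionA = Arrays.asList("a", "b");
    List<String> partitionB = Arrays.asList("c");

    // Phase 1: partial combine per partition.
    List<String> accA = createAccumulator();
    for (String v : partitionA) { accA = addInput(accA, v); }
    List<String> accB = createAccumulator();
    for (String v : partitionB) { accB = addInput(accB, v); }

    // Phase 2: merge accumulators after the shuffle; extractOutput is the identity here.
    List<String> grouped = mergeAccumulators(Arrays.asList(accA, accB));
    System.out.println(grouped); // [a, b, c] -- the grouped values for the key
  }
}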
Current ignores windows. + * Combiner that combines {@code T}s into a single {@code List} containing all inputs. + * + *

    <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this + * is expected to crash! + + *

    This is copied from the dataflow runner code. + * + * @param the type of elements to concatenate. */ - private static class GroupByKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class Concatenate extends Combine.CombineFn, List> { + @Override + public List createAccumulator() { + return new ArrayList(); + } @Override - public void translateNode(GroupByKey transform, FlinkBatchTranslationContext context) { - DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); - GroupReduceFunction, KV>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + public List addInput(List accumulator, T input) { + accumulator.add(input); + return accumulator; + } - TypeInformation>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + @Override + public List mergeAccumulators(Iterable> accumulators) { + List result = createAccumulator(); + for (List accumulator : accumulators) { + result.addAll(accumulator); + } + return result; + } - Grouping> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + @Override + public List extractOutput(List accumulator) { + return accumulator; + } - GroupReduceOperator, KV>> outputDataSet = - new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } - context.setOutputDataSet(context.getOutput(transform), outputDataSet); + @Override + public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); } } - private static class CombinePerKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + private static class CombinePerKeyTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + Combine.PerKey> { @Override - public void translateNode(Combine.PerKey transform, FlinkBatchTranslationContext context) { - DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); + @SuppressWarnings("unchecked") + public void translateNode( + Combine.PerKey transform, + FlinkBatchTranslationContext context) { + DataSet>> inputDataSet = + context.getInputDataSet(context.getInput(transform)); - @SuppressWarnings("unchecked") - Combine.KeyedCombineFn keyedCombineFn = (Combine.KeyedCombineFn) transform.getFn(); + CombineFnBase.PerKeyCombineFn combineFn = + (CombineFnBase.PerKeyCombineFn) transform.getFn(); + + KvCoder inputCoder = (KvCoder) context.getInput(transform).getCoder(); - KvCoder inputCoder = (KvCoder) context.getInput(transform).getCoder(); + Coder accumulatorCoder; - Coder accumulatorCoder = - null; try { - accumulatorCoder = keyedCombineFn.getAccumulatorCoder(context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + accumulatorCoder = + combineFn.getAccumulatorCoder( + context.getInput(transform).getPipeline().getCoderRegistry(), + inputCoder.getKeyCoder(), + inputCoder.getValueCoder()); } catch (CannotProvideCoderException e) { - e.printStackTrace(); - // TODO + throw new RuntimeException(e); } - TypeInformation> kvCoderTypeInformation = new KvCoderTypeInformation<>(inputCoder); - TypeInformation> partialReduceTypeInfo = new KvCoderTypeInformation<>(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)); + WindowingStrategy windowingStrategy = + 
context.getInput(transform).getWindowingStrategy(); + + TypeInformation>> kvCoderTypeInformation = + new KvCoderTypeInformation<>( + WindowedValue.getFullCoder( + inputCoder, + windowingStrategy.getWindowFn().windowCoder())); + + TypeInformation>> partialReduceTypeInfo = + new KvCoderTypeInformation<>( + WindowedValue.getFullCoder( + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + windowingStrategy.getWindowFn().windowCoder())); + + Grouping>> inputGrouping = + new UnsortedGrouping<>( + inputDataSet, + new Keys.ExpressionKeys<>(new String[]{"key"}, + kvCoderTypeInformation)); + + // construct a map from side input to WindowingStrategy so that + // the DoFn runner can map main-input windows to side input windows + Map, WindowingStrategy> sideInputStrategies = new HashMap<>(); + for (PCollectionView sideInput: transform.getSideInputs()) { + sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); + } - Grouping> inputGrouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); + if (windowingStrategy.getWindowFn().isNonMerging()) { + WindowingStrategy boundedStrategy = + (WindowingStrategy) windowingStrategy; + + FlinkPartialReduceFunction partialReduceFunction = + new FlinkPartialReduceFunction<>( + combineFn, + boundedStrategy, + sideInputStrategies, + context.getPipelineOptions()); + + FlinkReduceFunction reduceFunction = + new FlinkReduceFunction<>( + combineFn, + boundedStrategy, + sideInputStrategies, + context.getPipelineOptions()); + + // Partially GroupReduce the values into the intermediate format AccumT (combine) + GroupCombineOperator< + WindowedValue>, + WindowedValue>> groupCombine = + new GroupCombineOperator<>( + inputGrouping, + partialReduceTypeInfo, + partialReduceFunction, + "GroupCombine: " + transform.getName()); + + transformSideInputs(transform.getSideInputs(), groupCombine, context); + + TypeInformation>> reduceTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + Grouping>> intermediateGrouping = + new UnsortedGrouping<>( + groupCombine, + new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + + // Fully reduce the values and create output format OutputT + GroupReduceOperator< + WindowedValue>, WindowedValue>> outputDataSet = + new GroupReduceOperator<>( + intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); - FlinkPartialReduceFunction partialReduceFunction = new FlinkPartialReduceFunction<>(keyedCombineFn); + } else { + if (!windowingStrategy.getWindowFn().windowCoder().equals(IntervalWindow.getCoder())) { + throw new UnsupportedOperationException( + "Merging WindowFn with windows other than IntervalWindow are not supported."); + } - // Partially GroupReduce the values into the intermediate format VA (combine) - GroupCombineOperator, KV> groupCombine = - new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, - "GroupCombine: " + transform.getName()); + // for merging windows we can't to a pre-shuffle combine step since + // elements would not be in their correct windows for side-input access - // Reduce fully to VO - GroupReduceFunction, KV> reduceFunction = new FlinkReduceFunction<>(keyedCombineFn); + WindowingStrategy intervalStrategy = + (WindowingStrategy) windowingStrategy; - TypeInformation> reduceTypeInfo = 
context.getTypeInfo(context.getOutput(transform)); + FlinkMergingNonShuffleReduceFunction reduceFunction = + new FlinkMergingNonShuffleReduceFunction<>( + combineFn, + intervalStrategy, + sideInputStrategies, + context.getPipelineOptions()); - Grouping> intermediateGrouping = new UnsortedGrouping<>(groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + TypeInformation>> reduceTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + Grouping>> grouping = + new UnsortedGrouping<>( + inputDataSet, + new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); + + // Fully reduce the values and create output format OutputT + GroupReduceOperator< + WindowedValue>, WindowedValue>> outputDataSet = + new GroupReduceOperator<>( + grouping, reduceTypeInfo, reduceFunction, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } - // Fully reduce the values and create output format VO - GroupReduceOperator, KV> outputDataSet = - new GroupReduceOperator<>(intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); - context.setOutputDataSet(context.getOutput(transform), outputDataSet); } } -// private static class CombineGroupedValuesTranslator implements FlinkPipelineTranslator.TransformTranslator> { -// -// @Override -// public void translateNode(Combine.GroupedValues transform, TranslationContext context) { -// DataSet> inputDataSet = context.getInputDataSet(transform.getInput()); -// -// Combine.KeyedCombineFn keyedCombineFn = transform.getFn(); -// -// GroupReduceFunction, KV> groupReduceFunction = new FlinkCombineFunction<>(keyedCombineFn); -// -// TypeInformation> typeInformation = context.getTypeInfo(transform.getOutput()); -// -// Grouping> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{""}, inputDataSet.getType())); -// -// GroupReduceOperator, KV> outputDataSet = -// new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); -// context.setOutputDataSet(transform.getOutput(), outputDataSet); -// } -// } - - private static class ParDoBoundTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorBatch.class); + private static class ParDoBoundTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + ParDo.Bound> { @Override - public void translateNode(ParDo.Bound transform, FlinkBatchTranslationContext context) { - DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + public void translateNode( + ParDo.Bound transform, + FlinkBatchTranslationContext context) { + DataSet> inputDataSet = + context.getInputDataSet(context.getInput(transform)); - final DoFn doFn = transform.getFn(); + final DoFn doFn = transform.getFn(); - TypeInformation typeInformation = context.getTypeInfo(context.getOutput(transform)); + TypeInformation> typeInformation = + context.getTypeInfo(context.getOutput(transform)); - FlinkDoFnFunction doFnWrapper = new FlinkDoFnFunction<>(doFn, context.getPipelineOptions()); - MapPartitionOperator outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + List> sideInputs = transform.getSideInputs(); - transformSideInputs(transform.getSideInputs(), outputDataSet, context); + // construct a map from side input to 
WindowingStrategy so that + // the DoFn runner can map main-input windows to side input windows + Map, WindowingStrategy> sideInputStrategies = new HashMap<>(); + for (PCollectionView sideInput: sideInputs) { + sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); + } + + FlinkDoFnFunction doFnWrapper = + new FlinkDoFnFunction<>( + doFn, + context.getOutput(transform).getWindowingStrategy(), + sideInputStrategies, + context.getPipelineOptions()); + + MapPartitionOperator, WindowedValue> outputDataSet = + new MapPartitionOperator<>( + inputDataSet, + typeInformation, + doFnWrapper, + transform.getName()); + + transformSideInputs(sideInputs, outputDataSet, context); context.setOutputDataSet(context.getOutput(transform), outputDataSet); } } - private static class ParDoBoundMultiTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundMultiTranslatorBatch.class); + private static class ParDoBoundMultiTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + ParDo.BoundMulti> { @Override - public void translateNode(ParDo.BoundMulti transform, FlinkBatchTranslationContext context) { - DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); + public void translateNode( + ParDo.BoundMulti transform, + FlinkBatchTranslationContext context) { + DataSet> inputDataSet = + context.getInputDataSet(context.getInput(transform)); - final DoFn doFn = transform.getFn(); + final DoFn doFn = transform.getFn(); Map, PCollection> outputs = context.getOutput(transform).getAll(); Map, Integer> outputMap = Maps.newHashMap(); - // put the main output at index 0, FlinkMultiOutputDoFnFunction also expects this + // put the main output at index 0, FlinkMultiOutputDoFnFunction expects this outputMap.put(transform.getMainOutputTag(), 0); int count = 1; for (TupleTag tag: outputs.keySet()) { @@ -435,58 +683,118 @@ public void translateNode(ParDo.BoundMulti transform, FlinkBatchTransla } } + // assume that the windowing strategy is the same for all outputs + WindowingStrategy windowingStrategy = null; + // collect all output Coders and create a UnionCoder for our tagged outputs List> outputCoders = Lists.newArrayList(); for (PCollection coll: outputs.values()) { outputCoders.add(coll.getCoder()); + windowingStrategy = coll.getWindowingStrategy(); + } + + if (windowingStrategy == null) { + throw new IllegalStateException("No outputs defined."); } UnionCoder unionCoder = UnionCoder.of(outputCoders); - @SuppressWarnings("unchecked") - TypeInformation typeInformation = new CoderTypeInformation<>(unionCoder); + TypeInformation> typeInformation = + new CoderTypeInformation<>( + WindowedValue.getFullCoder( + unionCoder, + windowingStrategy.getWindowFn().windowCoder())); - @SuppressWarnings("unchecked") - FlinkMultiOutputDoFnFunction doFnWrapper = new FlinkMultiOutputDoFnFunction(doFn, context.getPipelineOptions(), outputMap); - MapPartitionOperator outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + List> sideInputs = transform.getSideInputs(); - transformSideInputs(transform.getSideInputs(), outputDataSet, context); + // construct a map from side input to WindowingStrategy so that + // the DoFn runner can map main-input windows to side input windows + Map, WindowingStrategy> sideInputStrategies = new HashMap<>(); + for (PCollectionView sideInput: sideInputs) { + sideInputStrategies.put(sideInput, 
sideInput.getWindowingStrategyInternal()); + } - for (Map.Entry, PCollection> output: outputs.entrySet()) { - TypeInformation outputType = context.getTypeInfo(output.getValue()); - int outputTag = outputMap.get(output.getKey()); - FlinkMultiOutputPruningFunction pruningFunction = new FlinkMultiOutputPruningFunction<>(outputTag); - FlatMapOperator pruningOperator = new - FlatMapOperator<>(outputDataSet, outputType, - pruningFunction, output.getValue().getName()); - context.setOutputDataSet(output.getValue(), pruningOperator); + @SuppressWarnings("unchecked") + FlinkMultiOutputDoFnFunction doFnWrapper = + new FlinkMultiOutputDoFnFunction( + doFn, + windowingStrategy, + sideInputStrategies, + context.getPipelineOptions(), + outputMap); + + MapPartitionOperator, WindowedValue> taggedDataSet = + new MapPartitionOperator<>( + inputDataSet, + typeInformation, + doFnWrapper, + transform.getName()); + + transformSideInputs(sideInputs, taggedDataSet, context); + for (Map.Entry, PCollection> output: outputs.entrySet()) { + pruneOutput( + taggedDataSet, + context, + outputMap.get(output.getKey()), + (PCollection) output.getValue()); } } + + private void pruneOutput( + MapPartitionOperator, WindowedValue> taggedDataSet, + FlinkBatchTranslationContext context, + int integerTag, + PCollection collection) { + TypeInformation> outputType = context.getTypeInfo(collection); + + FlinkMultiOutputPruningFunction pruningFunction = + new FlinkMultiOutputPruningFunction<>(integerTag); + + FlatMapOperator, WindowedValue> pruningOperator = + new FlatMapOperator<>( + taggedDataSet, + outputType, + pruningFunction, + collection.getName()); + + context.setOutputDataSet(collection, pruningOperator); + } } - private static class FlattenPCollectionTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class FlattenPCollectionTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + Flatten.FlattenPCollectionList> { @Override @SuppressWarnings("unchecked") - public void translateNode(Flatten.FlattenPCollectionList transform, FlinkBatchTranslationContext context) { + public void translateNode( + Flatten.FlattenPCollectionList transform, + FlinkBatchTranslationContext context) { + List> allInputs = context.getInput(transform).getAll(); - DataSet result = null; + DataSet> result = null; + if (allInputs.isEmpty()) { + // create an empty dummy source to satisfy downstream operations // we cannot create an empty source in Flink, therefore we have to // add the flatMap that simply never forwards the single element DataSource dummySource = context.getExecutionEnvironment().fromElements("dummy"); - result = dummySource.flatMap(new FlatMapFunction() { + result = dummySource.flatMap(new FlatMapFunction>() { @Override - public void flatMap(String s, Collector collector) throws Exception { + public void flatMap(String s, Collector> collector) throws Exception { // never return anything } - }).returns(new CoderTypeInformation<>((Coder) VoidCoder.of())); + }).returns( + new CoderTypeInformation<>( + WindowedValue.getFullCoder( + (Coder) VoidCoder.of(), + GlobalWindow.Coder.INSTANCE))); } else { for (PCollection collection : allInputs) { - DataSet current = context.getInputDataSet(collection); + DataSet> current = context.getInputDataSet(collection); if (result == null) { result = current; } else { @@ -494,103 +802,47 @@ public void flatMap(String s, Collector collector) throws Exception { } } } - context.setOutputDataSet(context.getOutput(transform), result); - } - 
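The multi-output ParDo handling above tags every emitted element with the integer index of its output and later prunes the tagged stream down to one output per tag. A minimal plain-Java sketch of that pruning step (TaggedValue is an illustrative stand-in for RawUnionValue, not the SDK class):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Stand-in for RawUnionValue: a value paired with the integer tag of the output it belongs to.
public class PruneTaggedOutputSketch {

  static final class TaggedValue {
    final int unionTag;
    final Object value;
    TaggedValue(int unionTag, Object value) { this.unionTag = unionTag; this.value = value; }
  }

  // Mirrors what the per-output pruning function does: keep one tag, unwrap the value.
  static List<Object> prune(List<TaggedValue> tagged, int wantedTag) {
    List<Object> out = new ArrayList<>();
    for (TaggedValue tv : tagged) {
      if (tv.unionTag == wantedTag) {
        out.add(tv.value);
      }
    }
    return out;
  }

  public static void main(String[] args) {
    // The main output is registered at index 0; side outputs get the following indices.
    List<TaggedValue> tagged = Arrays.asList(
        new TaggedValue(0, "main-1"), new TaggedValue(1, "side-1"), new TaggedValue(0, "main-2"));
    System.out.println(prune(tagged, 0)); // [main-1, main-2]
    System.out.println(prune(tagged, 1)); // [side-1]
  }
}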
} - private static class CreatePCollectionViewTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { - @Override - public void translateNode(View.CreatePCollectionView transform, FlinkBatchTranslationContext context) { - DataSet inputDataSet = context.getInputDataSet(context.getInput(transform)); - PCollectionView input = transform.apply(null); - context.setSideInputDataSet(input, inputDataSet); + // insert a dummy filter, there seems to be a bug in Flink + // that produces duplicate elements after the union in some cases + // if we don't + result = result.filter(new FilterFunction>() { + @Override + public boolean filter(WindowedValue tWindowedValue) throws Exception { + return true; + } + }).name("UnionFixFilter"); + context.setOutputDataSet(context.getOutput(transform), result); } } - private static class CreateTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class CreatePCollectionViewTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + View.CreatePCollectionView> { @Override - public void translateNode(Create.Values transform, FlinkBatchTranslationContext context) { - TypeInformation typeInformation = context.getOutputTypeInfo(); - Iterable elements = transform.getElements(); - - // we need to serialize the elements to byte arrays, since they might contain - // elements that are not serializable by Java serialization. We deserialize them - // in the FlatMap function using the Coder. - - List serializedElements = Lists.newArrayList(); - Coder coder = context.getOutput(transform).getCoder(); - for (OUT element: elements) { - ByteArrayOutputStream bao = new ByteArrayOutputStream(); - try { - coder.encode(element, bao, Coder.Context.OUTER); - serializedElements.add(bao.toByteArray()); - } catch (IOException e) { - throw new RuntimeException("Could not serialize Create elements using Coder: " + e); - } - } + public void translateNode( + View.CreatePCollectionView transform, + FlinkBatchTranslationContext context) { + DataSet> inputDataSet = + context.getInputDataSet(context.getInput(transform)); - DataSet initDataSet = context.getExecutionEnvironment().fromElements(1); - FlinkCreateFunction flatMapFunction = new FlinkCreateFunction<>(serializedElements, coder); - FlatMapOperator outputDataSet = new FlatMapOperator<>(initDataSet, typeInformation, flatMapFunction, transform.getName()); + PCollectionView input = transform.getView(); - context.setOutputDataSet(context.getOutput(transform), outputDataSet); + context.setSideInputDataSet(input, inputDataSet); } } - private static void transformSideInputs(List> sideInputs, - MapPartitionOperator outputDataSet, - FlinkBatchTranslationContext context) { + private static void transformSideInputs( + List> sideInputs, + SingleInputUdfOperator outputDataSet, + FlinkBatchTranslationContext context) { // get corresponding Flink broadcast DataSets - for(PCollectionView input : sideInputs) { + for (PCollectionView input : sideInputs) { DataSet broadcastSet = context.getSideInputDataSet(input); outputDataSet.withBroadcastSet(broadcastSet, input.getTagInternal().getId()); } } -// Disabled because it depends on a pending pull request to the DataFlowSDK - /** - * Special composite transform translator. Only called if the CoGroup is two dimensional. 
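transformSideInputs above hands each side input to Flink as a broadcast set registered under the view's tag id, and the function later reads it back through the runtime context. A small stand-alone example of that broadcast mechanism with the classic Flink DataSet API (the "side" name and the offset arithmetic are illustrative only):

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;

import java.util.List;

public class BroadcastSideInputSketch {
  public static void main(String[] args) throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Integer> mainInput = env.fromElements(1, 2, 3);
    DataSet<Integer> sideInput = env.fromElements(100); // plays the role of a side input view

    DataSet<Integer> result = mainInput
        .map(new RichMapFunction<Integer, Integer>() {
          private int offset;

          @Override
          public void open(Configuration parameters) {
            // materialize the broadcast side input, the same way the wired-up functions do
            List<Integer> side = getRuntimeContext().getBroadcastVariable("side");
            offset = side.get(0);
          }

          @Override
          public Integer map(Integer value) {
            return value + offset;
          }
        })
        .withBroadcastSet(sideInput, "side"); // same mechanism as withBroadcastSet(broadcastSet, tag)

    result.print(); // 101, 102, 103
  }
}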
- * @param - */ - private static class CoGroupByKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { - - @Override - public void translateNode(CoGroupByKey transform, FlinkBatchTranslationContext context) { - KeyedPCollectionTuple input = context.getInput(transform); - - CoGbkResultSchema schema = input.getCoGbkResultSchema(); - List> keyedCollections = input.getKeyedCollections(); - - KeyedPCollectionTuple.TaggedKeyedPCollection taggedCollection1 = keyedCollections.get(0); - KeyedPCollectionTuple.TaggedKeyedPCollection taggedCollection2 = keyedCollections.get(1); - - TupleTag tupleTag1 = taggedCollection1.getTupleTag(); - TupleTag tupleTag2 = taggedCollection2.getTupleTag(); - - PCollection> collection1 = taggedCollection1.getCollection(); - PCollection> collection2 = taggedCollection2.getCollection(); - - DataSet> inputDataSet1 = context.getInputDataSet(collection1); - DataSet> inputDataSet2 = context.getInputDataSet(collection2); - - TypeInformation> typeInfo = context.getOutputTypeInfo(); - - FlinkCoGroupKeyedListAggregator aggregator = new FlinkCoGroupKeyedListAggregator<>(schema, tupleTag1, tupleTag2); - - Keys.ExpressionKeys> keySelector1 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet1.getType()); - Keys.ExpressionKeys> keySelector2 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet2.getType()); - - DataSet> out = new CoGroupOperator<>(inputDataSet1, inputDataSet2, - keySelector1, keySelector2, - aggregator, typeInfo, null, transform.getName()); - context.setOutputDataSet(context.getOutput(transform), out); - } - } - - // -------------------------------------------------------------------------------------------- - // Miscellaneous - // -------------------------------------------------------------------------------------------- - private FlinkBatchTransformTranslators() {} } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java index 501b1ea5555c..ecc3a65c7965 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTranslationContext.java @@ -18,26 +18,28 @@ package org.apache.beam.runners.flink.translation; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; -import org.apache.beam.runners.flink.translation.types.KvCoderTypeInformation; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TypedPValue; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; import java.util.HashMap; import java.util.Map; +/** + * Helper for {@link FlinkBatchPipelineTranslator} and translators in + * {@link 
FlinkBatchTransformTranslators}. + */ public class FlinkBatchTranslationContext { private final Map> dataSets; @@ -81,13 +83,13 @@ public PipelineOptions getPipelineOptions() { } @SuppressWarnings("unchecked") - public DataSet getInputDataSet(PValue value) { + public DataSet> getInputDataSet(PValue value) { // assume that the DataSet is used as an input if retrieved here danglingDataSets.remove(value); - return (DataSet) dataSets.get(value); + return (DataSet>) dataSets.get(value); } - public void setOutputDataSet(PValue value, DataSet set) { + public void setOutputDataSet(PValue value, DataSet> set) { if (!dataSets.containsKey(value)) { dataSets.put(value, set); danglingDataSets.put(value, set); @@ -107,40 +109,32 @@ public DataSet getSideInputDataSet(PCollectionView value) { return (DataSet) broadcastDataSets.get(value); } - public void setSideInputDataSet(PCollectionView value, DataSet set) { + public void setSideInputDataSet( + PCollectionView value, + DataSet> set) { if (!broadcastDataSets.containsKey(value)) { broadcastDataSets.put(value, set); } } - - @SuppressWarnings("unchecked") - public TypeInformation getTypeInfo(PInput output) { - if (output instanceof TypedPValue) { - Coder outputCoder = ((TypedPValue) output).getCoder(); - if (outputCoder instanceof KvCoder) { - return new KvCoderTypeInformation((KvCoder) outputCoder); - } else { - return new CoderTypeInformation(outputCoder); - } - } - return new GenericTypeInfo<>((Class)Object.class); - } - - public TypeInformation getInputTypeInfo() { - return getTypeInfo(currentTransform.getInput()); - } - public TypeInformation getOutputTypeInfo() { - return getTypeInfo((PValue) currentTransform.getOutput()); + @SuppressWarnings("unchecked") + public TypeInformation> getTypeInfo(PCollection collection) { + Coder valueCoder = collection.getCoder(); + WindowedValue.FullWindowedValueCoder windowedValueCoder = + WindowedValue.getFullCoder( + valueCoder, + collection.getWindowingStrategy().getWindowFn().windowCoder()); + + return new CoderTypeInformation<>(windowedValueCoder); } @SuppressWarnings("unchecked") - I getInput(PTransform transform) { - return (I) currentTransform.getInput(); + T getInput(PTransform transform) { + return (T) currentTransform.getInput(); } @SuppressWarnings("unchecked") - O getOutput(PTransform transform) { - return (O) currentTransform.getOutput(); + T getOutput(PTransform transform) { + return (T) currentTransform.getOutput(); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java index 2778d5c3166e..b3fed99ad39f 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.flink.translation; -import org.apache.beam.runners.flink.translation.functions.UnionCoder; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.FlinkCoder; import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; @@ -46,6 +45,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.join.RawUnionValue; +import 
org.apache.beam.sdk.transforms.join.UnionCoder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; @@ -229,29 +229,15 @@ public void translateNode(Read.Bounded transform, FlinkStreamingTranslationCo BoundedSource boundedSource = transform.getSource(); PCollection output = context.getOutput(transform); - Coder defaultOutputCoder = boundedSource.getDefaultOutputCoder(); - CoderTypeInformation typeInfo = new CoderTypeInformation<>(defaultOutputCoder); + TypeInformation> typeInfo = context.getTypeInfo(output); - DataStream source = context.getExecutionEnvironment().createInput( + DataStream> source = context.getExecutionEnvironment().createInput( new SourceInputFormat<>( boundedSource, context.getPipelineOptions()), typeInfo); - DataStream> windowedStream = source.flatMap( - new FlatMapFunction>() { - @Override - public void flatMap(T value, Collector> out) throws Exception { - out.collect( - WindowedValue.of(value, - Instant.now(), - GlobalWindow.INSTANCE, - PaneInfo.NO_FIRING)); - } - }) - .assignTimestampsAndWatermarks(new IngestionTimeExtractor>()); - - context.setOutputDataStream(output, windowedStream); + context.setOutputDataStream(output, source); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java index 8bc73172405c..0cb80baa7cc8 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTranslationContext.java @@ -17,21 +17,30 @@ */ package org.apache.beam.runners.flink.translation; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; import com.google.common.base.Preconditions; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import java.util.HashMap; import java.util.Map; +/** + * Helper for keeping track of which {@link DataStream DataStreams} map + * to which {@link PTransform PTransforms}. 
+ */ public class FlinkStreamingTranslationContext { private final StreamExecutionEnvironment env; @@ -80,12 +89,24 @@ public void setCurrentTransform(AppliedPTransform currentTransform) { } @SuppressWarnings("unchecked") - public I getInput(PTransform transform) { - return (I) currentTransform.getInput(); + public TypeInformation> getTypeInfo(PCollection collection) { + Coder valueCoder = collection.getCoder(); + WindowedValue.FullWindowedValueCoder windowedValueCoder = + WindowedValue.getFullCoder( + valueCoder, + collection.getWindowingStrategy().getWindowFn().windowCoder()); + + return new CoderTypeInformation<>(windowedValueCoder); + } + + + @SuppressWarnings("unchecked") + public T getInput(PTransform transform) { + return (T) currentTransform.getInput(); } @SuppressWarnings("unchecked") - public O getOutput(PTransform transform) { - return (O) currentTransform.getOutput(); + public T getOutput(PTransform transform) { + return (T) currentTransform.getOutput(); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignContext.java new file mode 100644 index 000000000000..7ea8c202f9d5 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignContext.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * {@link org.apache.beam.sdk.transforms.windowing.WindowFn.AssignContext} for + * Flink functions. 
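FlinkAssignWindows, introduced just below, calls the WindowFn's assignWindows for every element and re-emits one windowed value per assigned window. For intuition, a plain-Java sketch of the simplest case, fixed non-overlapping windows (the window arithmetic here is illustrative, not the SDK's FixedWindows implementation):

import java.util.Collections;
import java.util.List;

// Sketch: assign an element timestamp to a fixed, non-overlapping window of the given size.
public class FixedWindowAssignSketch {

  // A window is just a half-open interval [start, end) in millis here.
  static final class IntervalWindow {
    final long startMillis;
    final long endMillis;
    IntervalWindow(long startMillis, long endMillis) {
      this.startMillis = startMillis;
      this.endMillis = endMillis;
    }
    @Override public String toString() { return "[" + startMillis + ", " + endMillis + ")"; }
  }

  // Fixed windows assign each timestamp to exactly one window; sliding or merging
  // WindowFns may return several windows per element instead.
  static List<IntervalWindow> assignWindows(long timestampMillis, long sizeMillis) {
    long start = timestampMillis - Math.floorMod(timestampMillis, sizeMillis);
    return Collections.singletonList(new IntervalWindow(start, start + sizeMillis));
  }

  public static void main(String[] args) {
    // A timestamp at 90s with 60s windows lands in [60000, 120000).
    System.out.println(assignWindows(90_000L, 60_000L));
  }
}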
+ */ +class FlinkAssignContext + extends WindowFn.AssignContext { + private final WindowedValue value; + + FlinkAssignContext(WindowFn fn, WindowedValue value) { + fn.super(); + this.value = value; + } + + @Override + public InputT element() { + return value.getValue(); + } + + @Override + public Instant timestamp() { + return value.getTimestamp(); + } + + @Override + public Collection windows() { + return value.getWindows(); + } + +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignWindows.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignWindows.java new file mode 100644 index 000000000000..e07e49a2f060 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkAssignWindows.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.util.Collector; + +import java.util.Collection; + +/** + * Flink {@link FlatMapFunction} for implementing + * {@link org.apache.beam.sdk.transforms.windowing.Window.Bound}. + */ +public class FlinkAssignWindows + implements FlatMapFunction, WindowedValue> { + + private final WindowFn windowFn; + + public FlinkAssignWindows(WindowFn windowFn) { + this.windowFn = windowFn; + } + + @Override + public void flatMap( + WindowedValue input, Collector> collector) throws Exception { + Collection windows = windowFn.assignWindows(new FlinkAssignContext<>(windowFn, input)); + for (W window: windows) { + collector.collect( + WindowedValue.of(input.getValue(), input.getTimestamp(), window, input.getPane())); + } + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java deleted file mode 100644 index 8e7cdd75ca48..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCoGroupKeyedListAggregator.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.functions; - -import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.transforms.join.CoGbkResultSchema; -import org.apache.beam.sdk.transforms.join.RawUnionValue; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.TupleTag; - -import org.apache.flink.api.common.functions.CoGroupFunction; -import org.apache.flink.util.Collector; - -import java.util.ArrayList; -import java.util.List; - - -public class FlinkCoGroupKeyedListAggregator implements CoGroupFunction, KV, KV>{ - - private CoGbkResultSchema schema; - private TupleTag tupleTag1; - private TupleTag tupleTag2; - - public FlinkCoGroupKeyedListAggregator(CoGbkResultSchema schema, TupleTag tupleTag1, TupleTag tupleTag2) { - this.schema = schema; - this.tupleTag1 = tupleTag1; - this.tupleTag2 = tupleTag2; - } - - @Override - public void coGroup(Iterable> first, Iterable> second, Collector> out) throws Exception { - K k = null; - List result = new ArrayList<>(); - int index1 = schema.getIndex(tupleTag1); - for (KV entry : first) { - k = entry.getKey(); - result.add(new RawUnionValue(index1, entry.getValue())); - } - int index2 = schema.getIndex(tupleTag2); - for (KV entry : second) { - k = entry.getKey(); - result.add(new RawUnionValue(index2, entry.getValue())); - } - out.collect(KV.of(k, new CoGbkResult(schema, result))); - } -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java deleted file mode 100644 index e5ac7482cfcb..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkCreateFunction.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
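FlinkCreateFunction, deleted just below, consumed elements that the old Create translation had already encoded to byte arrays with their Coder, so that values which are not Java-serializable could still be shipped to the cluster. A plain-Java sketch of that encode-then-decode round trip (writeUTF/readUTF stand in for a real Coder here):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch of the encode-at-translation-time / decode-at-execution-time round trip:
// elements are turned into byte arrays up front so only bytes need to be shipped,
// and are decoded again where they are consumed.
public class CoderRoundTripSketch {

  static byte[] encode(String value) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (DataOutputStream out = new DataOutputStream(bytes)) {
      out.writeUTF(value); // stand-in for coder.encode(value, out, Context.OUTER)
    }
    return bytes.toByteArray();
  }

  static String decode(byte[] encoded) throws IOException {
    try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(encoded))) {
      return in.readUTF(); // stand-in for coder.decode(in, Context.OUTER)
    }
  }

  public static void main(String[] args) throws IOException {
    List<byte[]> shipped = new ArrayList<>();
    for (String element : Arrays.asList("a", "b", "c")) {
      shipped.add(encode(element));
    }
    // ... later, inside the executing function ...
    for (byte[] encoded : shipped) {
      System.out.println(decode(encoded));
    }
  }
}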
- */ -package org.apache.beam.runners.flink.translation.functions; - -import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; -import org.apache.beam.sdk.coders.Coder; - -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.util.Collector; - -import java.io.ByteArrayInputStream; -import java.util.List; - -/** - * This is a hack for transforming a {@link org.apache.beam.sdk.transforms.Create} - * operation. Flink does not allow {@code null} in it's equivalent operation: - * {@link org.apache.flink.api.java.ExecutionEnvironment#fromElements(Object[])}. Therefore - * we use a DataSource with one dummy element and output the elements of the Create operation - * inside this FlatMap. - */ -public class FlinkCreateFunction implements FlatMapFunction { - - private final List elements; - private final Coder coder; - - public FlinkCreateFunction(List elements, Coder coder) { - this.elements = elements; - this.coder = coder; - } - - @Override - @SuppressWarnings("unchecked") - public void flatMap(IN value, Collector out) throws Exception { - - for (byte[] element : elements) { - ByteArrayInputStream bai = new ByteArrayInputStream(element); - OUT outValue = coder.decode(bai, Coder.Context.OUTER); - if (outValue == null) { - // TODO Flink doesn't allow null values in records - out.collect((OUT) VoidCoderTypeSerializer.VoidValue.INSTANCE); - } else { - out.collect(outValue); - } - } - - out.close(); - } -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java index 3566f7e1070e..89243a3ede28 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -18,173 +18,85 @@ package org.apache.beam.runners.flink.translation.functions; import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; -import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; -import org.apache.beam.sdk.util.TimerInternals; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingInternals; -import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; - -import com.google.common.collect.ImmutableList; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.util.Collector; -import org.joda.time.Instant; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; +import java.util.Map; /** * Encapsulates a {@link org.apache.beam.sdk.transforms.DoFn} * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. 
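The reworked FlinkDoFnFunction keeps the bundle discipline per Flink partition: startBundle once, processElement for every input, finishBundle at the end, which may still emit output. A minimal plain-Java sketch of that loop (the SimpleDoFn interface is illustrative, not the SDK's DoFn):

import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

// Sketch of the per-partition bundle lifecycle that the mapPartition implementation follows.
public class BundleLifecycleSketch {

  interface SimpleDoFn<InputT, OutputT> {
    default void startBundle(Consumer<OutputT> out) {}
    void processElement(InputT element, Consumer<OutputT> out);
    default void finishBundle(Consumer<OutputT> out) {}
  }

  static <InputT, OutputT> void mapPartition(
      Iterable<InputT> partition, SimpleDoFn<InputT, OutputT> fn, Consumer<OutputT> out) {
    fn.startBundle(out);                 // once per partition
    for (InputT element : partition) {
      fn.processElement(element, out);   // once per element
    }
    fn.finishBundle(out);                // once per partition, may still emit output
  }

  public static void main(String[] args) {
    List<String> partition = Arrays.asList("a", "b");
    SimpleDoFn<String, String> upperCase = (element, out) -> out.accept(element.toUpperCase());
    mapPartition(partition, upperCase, System.out::println); // A, B
  }
}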
*/ -public class FlinkDoFnFunction extends RichMapPartitionFunction { +public class FlinkDoFnFunction + extends RichMapPartitionFunction, WindowedValue> { - private final DoFn doFn; + private final DoFn doFn; private final SerializedPipelineOptions serializedOptions; - public FlinkDoFnFunction(DoFn doFn, PipelineOptions options) { - this.doFn = doFn; - this.serializedOptions = new SerializedPipelineOptions(options); - } - - @Override - public void mapPartition(Iterable values, Collector out) throws Exception { - ProcessContext context = new ProcessContext(doFn, out); - this.doFn.startBundle(context); - for (IN value : values) { - context.inValue = value; - doFn.processElement(context); - } - this.doFn.finishBundle(context); - } - - private class ProcessContext extends DoFn.ProcessContext { - - IN inValue; - Collector outCollector; - - public ProcessContext(DoFn fn, Collector outCollector) { - fn.super(); - super.setupDelegateAggregators(); - this.outCollector = outCollector; - } - - @Override - public IN element() { - return this.inValue; - } - + private final Map, WindowingStrategy> sideInputs; - @Override - public Instant timestamp() { - return Instant.now(); - } + private final boolean requiresWindowAccess; + private final boolean hasSideInputs; - @Override - public BoundedWindow window() { - return GlobalWindow.INSTANCE; - } - - @Override - public PaneInfo pane() { - return PaneInfo.NO_FIRING; - } + private final WindowingStrategy windowingStrategy; - @Override - public WindowingInternals windowingInternals() { - return new WindowingInternals() { - @Override - public StateInternals stateInternals() { - return null; - } - - @Override - public void outputWindowedValue(OUT output, Instant timestamp, Collection windows, PaneInfo pane) { - - } - - @Override - public TimerInternals timerInternals() { - return null; - } + public FlinkDoFnFunction( + DoFn doFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions options) { + this.doFn = doFn; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializedPipelineOptions(options); + this.windowingStrategy = windowingStrategy; - @Override - public Collection windows() { - return ImmutableList.of(GlobalWindow.INSTANCE); - } + this.requiresWindowAccess = doFn instanceof DoFn.RequiresWindowAccess; + this.hasSideInputs = !sideInputs.isEmpty(); + } - @Override - public PaneInfo pane() { - return PaneInfo.NO_FIRING; - } + @Override + public void mapPartition( + Iterable> values, + Collector> out) throws Exception { + + FlinkProcessContext context = new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); - @Override - public void writePCollectionViewData(TupleTag tag, Iterable> data, Coder elemCoder) throws IOException { - } + this.doFn.startBundle(context); - @Override - public T sideInput(PCollectionView view, BoundedWindow mainInputWindow) { - throw new RuntimeException("sideInput() not implemented."); + if (!requiresWindowAccess || hasSideInputs) { + // we don't need to explode the windows + for (WindowedValue value : values) { + context = context.forWindowedValue(value); + doFn.processElement(context); + } + } else { + // we need to explode the windows because we have per-window + // side inputs and window access also only works if an element + // is in only one window + for (WindowedValue value : values) { + for (WindowedValue explodedValue: value.explodeWindows()) { + context = 
context.forWindowedValue(value); + doFn.processElement(context); } - }; - } - - @Override - public PipelineOptions getPipelineOptions() { - return serializedOptions.getPipelineOptions(); - } - - @Override - public T sideInput(PCollectionView view) { - List sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal().getId()); - List> windowedValueList = new ArrayList<>(sideInput.size()); - for (T input : sideInput) { - windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane())); } - return view.fromIterableInternal(windowedValueList); } - @Override - public void output(OUT output) { - outCollector.collect(output); - } - - @Override - public void outputWithTimestamp(OUT output, Instant timestamp) { - // not FLink's way, just output normally - output(output); - } - - @Override - public void sideOutput(TupleTag tag, T output) { - // ignore the side output, this can happen when a user does not register - // side outputs but then outputs using a freshly created TupleTag. - } - - @Override - public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { - sideOutput(tag, output); - } - - @Override - protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { - SerializableFnAggregatorWrapper wrapper = new SerializableFnAggregatorWrapper<>(combiner); - getRuntimeContext().addAccumulator(name, wrapper); - return wrapper; - } - - + // set the windowed value to null so that the logic + // or outputting in finishBundle kicks in + context = context.forWindowedValue(null); + this.doFn.finishBundle(context); } + } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java deleted file mode 100644 index 7c7084db287c..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkKeyedListAggregationFunction.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.functions; - -import org.apache.beam.sdk.values.KV; - -import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.util.Collector; - -import java.util.Iterator; - -/** - * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a - * {@link org.apache.beam.sdk.transforms.GroupByKey} operation. This reads the input - * {@link org.apache.beam.sdk.values.KV} elements, extracts the key and collects - * the values in a {@code List}. 
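The deleted FlinkKeyedListAggregationFunction paired the key of the first element in a group with a lazy, single-pass view over the values instead of copying them into a list. A compact plain-Java sketch of that pass-through pattern (Map.Entry stands in for KV):

import java.util.AbstractMap.SimpleEntry;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

// Sketch: group a run of equal-keyed entries by taking the key from the first entry
// and exposing the values as a lazy, single-pass Iterable instead of materializing a list.
public class PassThroughGroupingSketch {

  static <K, V> Map.Entry<K, Iterable<V>> group(Iterator<Map.Entry<K, V>> run) {
    Map.Entry<K, V> first = run.next();
    Iterable<V> values = () -> new Iterator<V>() {
      Map.Entry<K, V> pending = first; // hand back the already-consumed first value, then stream the rest

      @Override public boolean hasNext() {
        return pending != null || run.hasNext();
      }

      @Override public V next() {
        if (pending != null) {
          V value = pending.getValue();
          pending = null;
          return value;
        }
        if (!run.hasNext()) {
          throw new NoSuchElementException();
        }
        return run.next().getValue();
      }
    };
    return new SimpleEntry<>(first.getKey(), values);
  }

  public static void main(String[] args) {
    List<Map.Entry<String, Integer>> run = Arrays.<Map.Entry<String, Integer>>asList(
        new SimpleEntry<>("k", 1), new SimpleEntry<>("k", 2), new SimpleEntry<>("k", 3));
    Map.Entry<String, Iterable<Integer>> grouped = group(run.iterator());
    System.out.print(grouped.getKey() + " ->");
    for (int v : grouped.getValue()) {
      System.out.print(" " + v); // k -> 1 2 3 (values streamed once, not copied)
    }
    System.out.println();
  }
}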
- */ -public class FlinkKeyedListAggregationFunction implements GroupReduceFunction, KV>> { - - @Override - public void reduce(Iterable> values, Collector>> out) throws Exception { - Iterator> it = values.iterator(); - KV first = it.next(); - Iterable passThrough = new PassThroughIterable<>(first, it); - out.collect(KV.of(first.getKey(), passThrough)); - } - - private static class PassThroughIterable implements Iterable, Iterator { - private KV first; - private Iterator> iterator; - - public PassThroughIterable(KV first, Iterator> iterator) { - this.first = first; - this.iterator = iterator; - } - - @Override - public Iterator iterator() { - return this; - } - - @Override - public boolean hasNext() { - return first != null || iterator.hasNext(); - } - - @Override - public V next() { - if (first != null) { - V result = first.getValue(); - first = null; - return result; - } else { - return iterator.next().getValue(); - } - } - - @Override - public void remove() { - throw new UnsupportedOperationException("Cannot remove elements from input."); - } - } -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java new file mode 100644 index 000000000000..9074d72e0e15 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.PerKeyCombineFnRunner; +import org.apache.beam.sdk.util.PerKeyCombineFnRunners; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Special version of {@link FlinkReduceFunction} that supports merging windows. This + * assumes that the windows are {@link IntervalWindow IntervalWindows} and exhibits the + * same behaviour as {@code MergeOverlappingIntervalWindows}. + * + *
<p>
    This is different from the pair of function for the non-merging windows case + * in that we cannot do combining before the shuffle because elements would not + * yet be in their correct windows for side-input access. + */ +public class FlinkMergingNonShuffleReduceFunction< + K, InputT, AccumT, OutputT, W extends IntervalWindow> + extends RichGroupReduceFunction>, WindowedValue>> { + + private final CombineFnBase.PerKeyCombineFn combineFn; + + private final DoFn, KV> doFn; + + private final WindowingStrategy windowingStrategy; + + private final Map, WindowingStrategy> sideInputs; + + private final SerializedPipelineOptions serializedOptions; + + public FlinkMergingNonShuffleReduceFunction( + CombineFnBase.PerKeyCombineFn keyedCombineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + + this.combineFn = keyedCombineFn; + + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + + this.serializedOptions = new SerializedPipelineOptions(pipelineOptions); + + // dummy DoFn because we need one for ProcessContext + this.doFn = new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) throws Exception { + + } + }; + } + + @Override + public void reduce( + Iterable>> elements, + Collector>> out) throws Exception { + + FlinkProcessContext, KV> processContext = + new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); + + PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(combineFn); + + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = + (OutputTimeFn) windowingStrategy.getOutputTimeFn(); + + // get all elements so that we can sort them, has to fit into + // memory + // this seems very unprudent, but correct, for now + List>> sortedInput = Lists.newArrayList(); + for (WindowedValue> inputValue: elements) { + for (WindowedValue> exploded: inputValue.explodeWindows()) { + sortedInput.add(exploded); + } + } + Collections.sort(sortedInput, new Comparator>>() { + @Override + public int compare( + WindowedValue> o1, + WindowedValue> o2) { + return Iterables.getOnlyElement(o1.getWindows()).maxTimestamp() + .compareTo(Iterables.getOnlyElement(o2.getWindows()).maxTimestamp()); + } + }); + + // merge windows, we have to do it in an extra pre-processing step and + // can't do it as we go since the window of early elements would not + // be correct when calling the CombineFn + mergeWindow(sortedInput); + + // iterate over the elements that are sorted by window timestamp + final Iterator>> iterator = sortedInput.iterator(); + + // create accumulator using the first elements key + WindowedValue> currentValue = iterator.next(); + K key = currentValue.getValue().getKey(); + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(currentValue.getWindows()); + InputT firstValue = currentValue.getValue().getValue(); + processContext = processContext.forWindowedValue(currentValue); + AccumT accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, accumulator, firstValue, processContext); + + // we use this to keep track of the timestamps assigned by the OutputTimeFn + Instant windowTimestamp = + outputTimeFn.assignOutputTime(currentValue.getTimestamp(), currentWindow); + + while (iterator.hasNext()) { + WindowedValue> nextValue = iterator.next(); + IntervalWindow nextWindow = (IntervalWindow) 
Iterables.getOnlyElement(nextValue.getWindows()); + + if (currentWindow.equals(nextWindow)) { + // continue accumulating and merge windows + + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + + windowTimestamp = outputTimeFn.combine( + windowTimestamp, + outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow)); + + } else { + // emit the value that we currently have + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); + + currentWindow = nextWindow; + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + windowTimestamp = outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow); + } + } + + // emit the final accumulator + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); + } + + /** + * Merge windows. This assumes that the list of elements is sorted by window-end timestamp. + * This replaces windows in the input list. + */ + private void mergeWindow(List>> elements) { + int currentStart = 0; + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(elements.get(0).getWindows()); + + for (int i = 1; i < elements.size(); i++) { + WindowedValue> nextValue = elements.get(i); + IntervalWindow nextWindow = + (IntervalWindow) Iterables.getOnlyElement(nextValue.getWindows()); + if (currentWindow.intersects(nextWindow)) { + // we continue + currentWindow = currentWindow.span(nextWindow); + } else { + // retrofit the merged window to all windows up to "currentStart" + for (int j = i - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + currentStart = i; + currentWindow = nextWindow; + } + } + if (currentStart < elements.size() - 1) { + // we have to retrofit the last batch + for (int j = elements.size() - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + } + } + +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingPartialReduceFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingPartialReduceFunction.java new file mode 100644 index 000000000000..c12e4204a3f0 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingPartialReduceFunction.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
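
Note on the mergeWindow pre-pass that appears in each of the merging reduce functions: it assumes the input is already sorted by window end and then folds every run of intersecting IntervalWindows into their span, which matches the behaviour the Javadoc attributes to MergeOverlappingIntervalWindows. The following standalone sketch (not part of this patch; class and method names are illustrative only) shows the same merge on bare windows, using only the public IntervalWindow API:

    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.joda.time.Instant;

    import java.util.ArrayList;
    import java.util.List;

    class IntervalWindowMergeSketch {
      // Folds a non-empty list of IntervalWindows, sorted by window end, into a list
      // where every run of intersecting windows is replaced by its span.
      static List<IntervalWindow> mergeOverlapping(List<IntervalWindow> sortedByEnd) {
        List<IntervalWindow> merged = new ArrayList<>();
        IntervalWindow current = sortedByEnd.get(0);
        for (int i = 1; i < sortedByEnd.size(); i++) {
          IntervalWindow next = sortedByEnd.get(i);
          if (current.intersects(next)) {
            current = current.span(next);   // grow the merged window
          } else {
            merged.add(current);            // close the current run
            current = next;
          }
        }
        merged.add(current);
        return merged;
      }

      public static void main(String[] args) {
        List<IntervalWindow> input = new ArrayList<>();
        input.add(new IntervalWindow(new Instant(0), new Instant(10)));
        input.add(new IntervalWindow(new Instant(5), new Instant(15)));
        input.add(new IntervalWindow(new Instant(20), new Instant(30)));
        // [0, 10) and [5, 15) intersect and are spanned into [0, 15); [20, 30) stays as is
        System.out.println(mergeOverlapping(input));
      }
    }

The only difference in the patch is that mergeWindow retrofits the merged window back onto the buffered WindowedValues instead of returning the windows themselves.
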
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.PerKeyCombineFnRunner; +import org.apache.beam.sdk.util.PerKeyCombineFnRunners; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Special version of {@link FlinkPartialReduceFunction} that supports merging windows. This + * assumes that the windows are {@link IntervalWindow IntervalWindows} and exhibits the + * same behaviour as {@code MergeOverlappingIntervalWindows}. + */ +public class FlinkMergingPartialReduceFunction + extends FlinkPartialReduceFunction { + + public FlinkMergingPartialReduceFunction( + CombineFnBase.PerKeyCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + super(combineFn, windowingStrategy, sideInputs, pipelineOptions); + } + + @Override + public void combine( + Iterable>> elements, + Collector>> out) throws Exception { + + FlinkProcessContext, KV> processContext = + new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); + + PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(combineFn); + + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = + (OutputTimeFn) windowingStrategy.getOutputTimeFn(); + + // get all elements so that we can sort them, has to fit into + // memory + // this seems very unprudent, but correct, for now + List>> sortedInput = Lists.newArrayList(); + for (WindowedValue> inputValue: elements) { + for (WindowedValue> exploded: inputValue.explodeWindows()) { + sortedInput.add(exploded); + } + } + Collections.sort(sortedInput, new Comparator>>() { + @Override + public int compare( + WindowedValue> o1, + WindowedValue> o2) { + return Iterables.getOnlyElement(o1.getWindows()).maxTimestamp() + .compareTo(Iterables.getOnlyElement(o2.getWindows()).maxTimestamp()); + } + }); + + // merge windows, we have to do it in an extra pre-processing step and + // can't do it as we go since the window of early elements would not + // be correct when calling the CombineFn + mergeWindow(sortedInput); + + // iterate over the elements that are sorted by window timestamp + final Iterator>> iterator = sortedInput.iterator(); + + // create accumulator using the first 
elements key + WindowedValue> currentValue = iterator.next(); + K key = currentValue.getValue().getKey(); + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(currentValue.getWindows()); + InputT firstValue = currentValue.getValue().getValue(); + processContext = processContext.forWindowedValue(currentValue); + AccumT accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, accumulator, firstValue, processContext); + + // we use this to keep track of the timestamps assigned by the OutputTimeFn + Instant windowTimestamp = + outputTimeFn.assignOutputTime(currentValue.getTimestamp(), currentWindow); + + while (iterator.hasNext()) { + WindowedValue> nextValue = iterator.next(); + IntervalWindow nextWindow = (IntervalWindow) Iterables.getOnlyElement(nextValue.getWindows()); + + if (currentWindow.equals(nextWindow)) { + // continue accumulating and merge windows + + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + + windowTimestamp = outputTimeFn.combine( + windowTimestamp, + outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow)); + + } else { + // emit the value that we currently have + out.collect( + WindowedValue.of( + KV.of(key, accumulator), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); + + currentWindow = nextWindow; + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + windowTimestamp = outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow); + } + } + + // emit the final accumulator + out.collect( + WindowedValue.of( + KV.of(key, accumulator), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); + } + + /** + * Merge windows. This assumes that the list of elements is sorted by window-end timestamp. + * This replaces windows in the input list. 
+ */ + private void mergeWindow(List>> elements) { + int currentStart = 0; + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(elements.get(0).getWindows()); + + for (int i = 1; i < elements.size(); i++) { + WindowedValue> nextValue = elements.get(i); + IntervalWindow nextWindow = + (IntervalWindow) Iterables.getOnlyElement(nextValue.getWindows()); + if (currentWindow.intersects(nextWindow)) { + // we continue + currentWindow = currentWindow.span(nextWindow); + } else { + // retrofit the merged window to all windows up to "currentStart" + for (int j = i - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + currentStart = i; + currentWindow = nextWindow; + } + } + if (currentStart < elements.size() - 1) { + // we have to retrofit the last batch + for (int j = elements.size() - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + } + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingReduceFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingReduceFunction.java new file mode 100644 index 000000000000..07d1c9741533 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingReduceFunction.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
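
A side note on the timestamp bookkeeping in these reduce functions: the timestamp attached to each emitted window is built with the OutputTimeFn of the windowing strategy, by calling assignOutputTime per element and then combine (or, after a shuffle, merge) across elements. A rough sketch of that contract, assuming the stock OutputTimeFns.outputAtEarliestInputTimestamp() helper, which keeps the earliest timestamp; the patch itself only depends on the abstract OutputTimeFn API:

    import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
    import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
    import org.joda.time.Instant;

    class OutputTimeFnSketch {
      public static void main(String[] args) {
        OutputTimeFn<BoundedWindow> outputTimeFn = OutputTimeFns.outputAtEarliestInputTimestamp();
        IntervalWindow window = new IntervalWindow(new Instant(0), new Instant(100));

        // per-element assignment followed by pairwise combination, as in the reduce functions
        Instant timestamp = outputTimeFn.assignOutputTime(new Instant(42), window);
        timestamp = outputTimeFn.combine(
            timestamp, outputTimeFn.assignOutputTime(new Instant(7), window));

        System.out.println(timestamp);  // the earlier of the two timestamps, i.e. millis 7
      }
    }
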
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.PerKeyCombineFnRunner; +import org.apache.beam.sdk.util.PerKeyCombineFnRunners; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Special version of {@link FlinkReduceFunction} that supports merging windows. This + * assumes that the windows are {@link IntervalWindow IntervalWindows} and exhibits the + * same behaviour as {@code MergeOverlappingIntervalWindows}. + */ +public class FlinkMergingReduceFunction + extends FlinkReduceFunction { + + public FlinkMergingReduceFunction( + CombineFnBase.PerKeyCombineFn keyedCombineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + super(keyedCombineFn, windowingStrategy, sideInputs, pipelineOptions); + } + + @Override + public void reduce( + Iterable>> elements, + Collector>> out) throws Exception { + + FlinkProcessContext, KV> processContext = + new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); + + PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(combineFn); + + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = + (OutputTimeFn) windowingStrategy.getOutputTimeFn(); + + + // get all elements so that we can sort them, has to fit into + // memory + // this seems very unprudent, but correct, for now + ArrayList>> sortedInput = Lists.newArrayList(); + for (WindowedValue> inputValue: elements) { + for (WindowedValue> exploded: inputValue.explodeWindows()) { + sortedInput.add(exploded); + } + } + Collections.sort(sortedInput, new Comparator>>() { + @Override + public int compare( + WindowedValue> o1, + WindowedValue> o2) { + return Iterables.getOnlyElement(o1.getWindows()).maxTimestamp() + .compareTo(Iterables.getOnlyElement(o2.getWindows()).maxTimestamp()); + } + }); + + // merge windows, we have to do it in an extra pre-processing step and + // can't do it as we go since the window of early elements would not + // be correct when calling the CombineFn + mergeWindow(sortedInput); + + // iterate over the elements that are sorted by window timestamp + final Iterator>> iterator = sortedInput.iterator(); + + // get the first accumulator + WindowedValue> currentValue = iterator.next(); + K key = currentValue.getValue().getKey(); + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(currentValue.getWindows()); + AccumT accumulator = currentValue.getValue().getValue(); + + // we use this to keep track of the timestamps assigned by the OutputTimeFn, + // in 
FlinkPartialReduceFunction we already merge the timestamps assigned + // to individual elements, here we just merge them + List windowTimestamps = new ArrayList<>(); + windowTimestamps.add(currentValue.getTimestamp()); + + while (iterator.hasNext()) { + WindowedValue> nextValue = iterator.next(); + IntervalWindow nextWindow = + (IntervalWindow) Iterables.getOnlyElement(nextValue.getWindows()); + + if (nextWindow.equals(currentWindow)) { + // continue accumulating and merge windows + + processContext = processContext.forWindowedValue(nextValue); + + accumulator = combineFnRunner.mergeAccumulators( + key, ImmutableList.of(accumulator, nextValue.getValue().getValue()), processContext); + + windowTimestamps.add(nextValue.getTimestamp()); + } else { + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + outputTimeFn.merge(currentWindow, windowTimestamps), + currentWindow, + PaneInfo.NO_FIRING)); + + windowTimestamps.clear(); + + processContext = processContext.forWindowedValue(nextValue); + + currentWindow = nextWindow; + accumulator = nextValue.getValue().getValue(); + windowTimestamps.add(nextValue.getTimestamp()); + } + } + + // emit the final accumulator + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + outputTimeFn.merge(currentWindow, windowTimestamps), + currentWindow, + PaneInfo.NO_FIRING)); + } + + /** + * Merge windows. This assumes that the list of elements is sorted by window-end timestamp. + * This replaces windows in the input list. + */ + private void mergeWindow(List>> elements) { + int currentStart = 0; + IntervalWindow currentWindow = + (IntervalWindow) Iterables.getOnlyElement(elements.get(0).getWindows()); + + for (int i = 1; i < elements.size(); i++) { + WindowedValue> nextValue = elements.get(i); + IntervalWindow nextWindow = + (IntervalWindow) Iterables.getOnlyElement(nextValue.getWindows()); + if (currentWindow.intersects(nextWindow)) { + // we continue + currentWindow = currentWindow.span(nextWindow); + } else { + // retrofit the merged window to all windows up to "currentStart" + for (int j = i - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + currentStart = i; + currentWindow = nextWindow; + } + } + if (currentStart < elements.size() - 1) { + // we have to retrofit the last batch + for (int j = elements.size() - 1; j >= currentStart; j--) { + WindowedValue> value = elements.get(j); + elements.set( + j, + WindowedValue.of( + value.getValue(), value.getTimestamp(), currentWindow, value.getPane())); + } + } + } + +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java index 476dc5e5f8e5..f92e76fa60cb 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java @@ -18,28 +18,17 @@ package org.apache.beam.runners.flink.translation.functions; import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; -import 
org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.join.RawUnionValue; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingInternals; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; -import com.google.common.collect.ImmutableList; - import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.util.Collector; -import org.joda.time.Instant; -import java.util.ArrayList; -import java.util.List; import java.util.Map; /** @@ -50,112 +39,72 @@ * and must tag all outputs with the output number. Afterwards a filter will filter out * those elements that are not to be in a specific output. */ -public class FlinkMultiOutputDoFnFunction extends RichMapPartitionFunction { - - private final DoFn doFn; - private final SerializedPipelineOptions serializedPipelineOptions; - private final Map, Integer> outputMap; - - public FlinkMultiOutputDoFnFunction(DoFn doFn, PipelineOptions options, Map, Integer> outputMap) { - this.doFn = doFn; - this.serializedPipelineOptions = new SerializedPipelineOptions(options); - this.outputMap = outputMap; - } - - @Override - public void mapPartition(Iterable values, Collector out) throws Exception { - ProcessContext context = new ProcessContext(doFn, out); - this.doFn.startBundle(context); - for (IN value : values) { - context.inValue = value; - doFn.processElement(context); - } - this.doFn.finishBundle(context); - } +public class FlinkMultiOutputDoFnFunction + extends RichMapPartitionFunction, WindowedValue> { - private class ProcessContext extends DoFn.ProcessContext { + private final DoFn doFn; + private final SerializedPipelineOptions serializedOptions; - IN inValue; - Collector outCollector; + private final Map, Integer> outputMap; - public ProcessContext(DoFn fn, Collector outCollector) { - fn.super(); - this.outCollector = outCollector; - } + private final Map, WindowingStrategy> sideInputs; - @Override - public IN element() { - return this.inValue; - } + private final boolean requiresWindowAccess; + private final boolean hasSideInputs; - @Override - public Instant timestamp() { - return Instant.now(); - } + private final WindowingStrategy windowingStrategy; - @Override - public BoundedWindow window() { - return GlobalWindow.INSTANCE; - } + public FlinkMultiOutputDoFnFunction( + DoFn doFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions options, + Map, Integer> outputMap) { + this.doFn = doFn; + this.serializedOptions = new SerializedPipelineOptions(options); + this.outputMap = outputMap; - @Override - public PaneInfo pane() { - return PaneInfo.NO_FIRING; - } + this.requiresWindowAccess = doFn instanceof DoFn.RequiresWindowAccess; + this.hasSideInputs = !sideInputs.isEmpty(); + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + } - @Override - public WindowingInternals windowingInternals() { - return null; - } + @Override + public void mapPartition( + Iterable> values, + Collector> out) 
throws Exception { + + FlinkProcessContext context = new FlinkMultiOutputProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + outputMap, + sideInputs); - @Override - public PipelineOptions getPipelineOptions() { - return serializedPipelineOptions.getPipelineOptions(); - } + this.doFn.startBundle(context); - @Override - public T sideInput(PCollectionView view) { - List sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal() - .getId()); - List> windowedValueList = new ArrayList<>(sideInput.size()); - for (T input : sideInput) { - windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane())); + if (!requiresWindowAccess || hasSideInputs) { + // we don't need to explode the windows + for (WindowedValue value : values) { + context = context.forWindowedValue(value); + doFn.processElement(context); } - return view.fromIterableInternal(windowedValueList); - } - - @Override - public void output(OUT value) { - // assume that index 0 is the default output - outCollector.collect(new RawUnionValue(0, value)); - } - - @Override - public void outputWithTimestamp(OUT output, Instant timestamp) { - // not FLink's way, just output normally - output(output); - } - - @Override - @SuppressWarnings("unchecked") - public void sideOutput(TupleTag tag, T value) { - Integer index = outputMap.get(tag); - if (index != null) { - outCollector.collect(new RawUnionValue(index, value)); + } else { + // we need to explode the windows because we have per-window + // side inputs and window access also only works if an element + // is in only one window + for (WindowedValue value : values) { + for (WindowedValue explodedValue: value.explodeWindows()) { + context = context.forWindowedValue(value); + doFn.processElement(context); + } } } - @Override - public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { - sideOutput(tag, output); - } - - @Override - protected Aggregator createAggregatorInternal(String name, Combine.CombineFn combiner) { - SerializableFnAggregatorWrapper wrapper = new SerializableFnAggregatorWrapper<>(combiner); - getRuntimeContext().addAccumulator(name, wrapper); - return null; - } + this.doFn.finishBundle(context); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputProcessContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputProcessContext.java new file mode 100644 index 000000000000..71b6d27ddba0 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputProcessContext.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
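
The explodeWindows() call used here (and in the reduce functions earlier in this patch) is what makes per-window processing well defined: a value that was assigned to several windows becomes several copies carrying exactly one window each, so window() access and per-window side-input lookup have an unambiguous window to work with. A small illustration, separate from the patch:

    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.PaneInfo;
    import org.apache.beam.sdk.util.WindowedValue;
    import org.joda.time.Instant;

    import java.util.Arrays;

    class ExplodeWindowsSketch {
      public static void main(String[] args) {
        WindowedValue<String> value = WindowedValue.of(
            "hello",
            new Instant(5),
            Arrays.asList(
                new IntervalWindow(new Instant(0), new Instant(10)),
                new IntervalWindow(new Instant(5), new Instant(15))),
            PaneInfo.NO_FIRING);

        // prints two windowed values, each carrying exactly one of the two windows
        for (WindowedValue<String> exploded : value.explodeWindows()) {
          System.out.println(exploded.getWindows());
        }
      }
    }
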
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.util.Collection; +import java.util.Map; + +/** + * {@link DoFn.ProcessContext} for {@link FlinkMultiOutputDoFnFunction} that supports + * side outputs. + */ +class FlinkMultiOutputProcessContext + extends FlinkProcessContext { + + // we need a different Collector from the base class + private final Collector> collector; + + private final Map, Integer> outputMap; + + + FlinkMultiOutputProcessContext( + PipelineOptions pipelineOptions, + RuntimeContext runtimeContext, + DoFn doFn, + WindowingStrategy windowingStrategy, + Collector> collector, + Map, Integer> outputMap, + Map, WindowingStrategy> sideInputs) { + super( + pipelineOptions, + runtimeContext, + doFn, + windowingStrategy, + new Collector>() { + @Override + public void collect(WindowedValue outputTWindowedValue) { + + } + + @Override + public void close() { + + } + }, + sideInputs); + + this.collector = collector; + this.outputMap = outputMap; + } + + @Override + public FlinkProcessContext forWindowedValue( + WindowedValue windowedValue) { + this.windowedValue = windowedValue; + return this; + } + + @Override + public void outputWithTimestamp(OutputT value, Instant timestamp) { + if (windowedValue == null) { + // we are in startBundle() or finishBundle() + + try { + Collection windows = windowingStrategy.getWindowFn().assignWindows( + new FlinkNoElementAssignContext( + windowingStrategy.getWindowFn(), + value, + timestamp)); + + collector.collect( + WindowedValue.of( + new RawUnionValue(0, value), + timestamp != null ? 
timestamp : new Instant(Long.MIN_VALUE), + windows, + PaneInfo.NO_FIRING)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + collector.collect( + WindowedValue.of( + new RawUnionValue(0, value), + windowedValue.getTimestamp(), + windowedValue.getWindows(), + windowedValue.getPane())); + } + } + + @Override + protected void outputWithTimestampAndWindow( + OutputT value, + Instant timestamp, + Collection windows, + PaneInfo pane) { + collector.collect( + WindowedValue.of( + new RawUnionValue(0, value), timestamp, windows, pane)); + } + + @Override + @SuppressWarnings("unchecked") + public void sideOutput(TupleTag tag, T value) { + if (windowedValue != null) { + sideOutputWithTimestamp(tag, value, windowedValue.getTimestamp()); + } else { + sideOutputWithTimestamp(tag, value, null); + } + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T value, Instant timestamp) { + Integer index = outputMap.get(tag); + + if (index == null) { + throw new IllegalArgumentException("Unknown side output tag: " + tag); + } + + if (windowedValue == null) { + // we are in startBundle() or finishBundle() + + try { + Collection windows = windowingStrategy.getWindowFn().assignWindows( + new FlinkNoElementAssignContext( + windowingStrategy.getWindowFn(), + value, + timestamp)); + + collector.collect( + WindowedValue.of( + new RawUnionValue(index, value), + timestamp != null ? timestamp : new Instant(Long.MIN_VALUE), + windows, + PaneInfo.NO_FIRING)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + collector.collect( + WindowedValue.of( + new RawUnionValue(index, value), + windowedValue.getTimestamp(), + windowedValue.getWindows(), + windowedValue.getPane())); + } + + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java index 58a36b27c5dd..9205a5520f82 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java @@ -18,27 +18,34 @@ package org.apache.beam.runners.flink.translation.functions; import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.util.Collector; /** - * A FlatMap function that filters out those elements that don't belong in this output. We need - * this to implement MultiOutput ParDo functions. + * A {@link FlatMapFunction} function that filters out those elements that don't belong in this + * output. We need this to implement MultiOutput ParDo functions in combination with + * {@link FlinkMultiOutputDoFnFunction}. 
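
For context on the tag-and-prune scheme shared by FlinkMultiOutputDoFnFunction and the pruning function below: every output, whether main or side, is wrapped in a RawUnionValue whose integer union tag comes from looking the TupleTag up in outputMap (index 0 is used for the main output in this code), and a downstream flatMap keeps only the tag it is responsible for. A toy illustration of the wrapping and pruning, not taken from the patch:

    import org.apache.beam.sdk.transforms.join.RawUnionValue;
    import org.apache.beam.sdk.values.TupleTag;

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class UnionTagSketch {
      public static void main(String[] args) {
        TupleTag<String> mainTag = new TupleTag<String>("main");
        TupleTag<String> sideTag = new TupleTag<String>("side");

        // the translator builds a map of this shape; 0 is treated as the main output
        Map<TupleTag<?>, Integer> outputMap = new HashMap<>();
        outputMap.put(mainTag, 0);
        outputMap.put(sideTag, 1);

        List<RawUnionValue> tagged = Arrays.asList(
            new RawUnionValue(outputMap.get(mainTag), "a"),
            new RawUnionValue(outputMap.get(sideTag), "b"),
            new RawUnionValue(outputMap.get(mainTag), "c"));

        // pruning for the main output keeps "a" and "c" and drops "b"
        for (RawUnionValue value : tagged) {
          if (value.getUnionTag() == 0) {
            System.out.println(value.getValue());
          }
        }
      }
    }
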
*/ -public class FlinkMultiOutputPruningFunction implements FlatMapFunction { +public class FlinkMultiOutputPruningFunction + implements FlatMapFunction, WindowedValue> { - private final int outputTag; + private final int ourOutputTag; - public FlinkMultiOutputPruningFunction(int outputTag) { - this.outputTag = outputTag; + public FlinkMultiOutputPruningFunction(int ourOutputTag) { + this.ourOutputTag = ourOutputTag; } @Override @SuppressWarnings("unchecked") - public void flatMap(RawUnionValue rawUnionValue, Collector collector) throws Exception { - if (rawUnionValue.getUnionTag() == outputTag) { - collector.collect((T) rawUnionValue.getValue()); + public void flatMap( + WindowedValue windowedValue, + Collector> collector) throws Exception { + int unionTag = windowedValue.getValue().getUnionTag(); + if (unionTag == ourOutputTag) { + collector.collect( + (WindowedValue) windowedValue.withValue(windowedValue.getValue().getValue())); } } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkNoElementAssignContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkNoElementAssignContext.java new file mode 100644 index 000000000000..892f7a1f33f0 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkNoElementAssignContext.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; + +import org.joda.time.Instant; + +import java.util.Collection; + +/** + * {@link WindowFn.AssignContext} for calling a {@link WindowFn} for elements emitted from + * {@link org.apache.beam.sdk.transforms.DoFn#startBundle(DoFn.Context)} + * or {@link DoFn#finishBundle(DoFn.Context)}. + * + *
<p>
    In those cases the {@code WindowFn} is not allowed to access any element information. + */ +class FlinkNoElementAssignContext + extends WindowFn.AssignContext { + + private final InputT element; + private final Instant timestamp; + + FlinkNoElementAssignContext( + WindowFn fn, + InputT element, + Instant timestamp) { + fn.super(); + + this.element = element; + // the timestamp can be null, in that case output is called + // without a timestamp + this.timestamp = timestamp; + } + + @Override + public InputT element() { + return element; + } + + @Override + public Instant timestamp() { + if (timestamp != null) { + return timestamp; + } else { + throw new UnsupportedOperationException("No timestamp available."); + } + } + + @Override + public Collection windows() { + throw new UnsupportedOperationException("No windows available."); + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java index a2bab2b3060f..c29e1df2ceb0 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java @@ -17,45 +17,170 @@ */ package org.apache.beam.runners.flink.translation.functions; -import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.PerKeyCombineFnRunner; +import org.apache.beam.sdk.util.PerKeyCombineFnRunners; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; -import org.apache.flink.api.common.functions.GroupCombineFunction; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +import org.apache.flink.api.common.functions.RichGroupCombineFunction; import org.apache.flink.util.Collector; +import org.joda.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; +import java.util.Map; /** - * Flink {@link org.apache.flink.api.common.functions.GroupCombineFunction} for executing a - * {@link org.apache.beam.sdk.transforms.Combine.PerKey} operation. This reads the input - * {@link org.apache.beam.sdk.values.KV} elements VI, extracts the key and emits accumulated - * values which have the intermediate format VA. + * This is is the first step for executing a {@link org.apache.beam.sdk.transforms.Combine.PerKey} + * on Flink. The second part is {@link FlinkReduceFunction}. This function performs a local + * combine step before shuffling while the latter does the final combination after a shuffle. + * + *
<p>
    The input to {@link #combine(Iterable, Collector)} are elements of the same key but + * for different windows. We have to ensure that we only combine elements of matching + * windows. */ -public class FlinkPartialReduceFunction implements GroupCombineFunction, KV> { +public class FlinkPartialReduceFunction + extends RichGroupCombineFunction>, WindowedValue>> { + + protected final CombineFnBase.PerKeyCombineFn combineFn; + + protected final DoFn, KV> doFn; + + protected final WindowingStrategy windowingStrategy; + + protected final SerializedPipelineOptions serializedOptions; - private final Combine.KeyedCombineFn keyedCombineFn; + protected final Map, WindowingStrategy> sideInputs; - public FlinkPartialReduceFunction(Combine.KeyedCombineFn - keyedCombineFn) { - this.keyedCombineFn = keyedCombineFn; + public FlinkPartialReduceFunction( + CombineFnBase.PerKeyCombineFn combineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { + + this.combineFn = combineFn; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializedPipelineOptions(pipelineOptions); + + // dummy DoFn because we need one for ProcessContext + this.doFn = new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) throws Exception { + + } + }; } @Override - public void combine(Iterable> elements, Collector> out) throws Exception { + public void combine( + Iterable>> elements, + Collector>> out) throws Exception { + + FlinkProcessContext, KV> processContext = + new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); + + PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(combineFn); + + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = + (OutputTimeFn) windowingStrategy.getOutputTimeFn(); + + // get all elements so that we can sort them, has to fit into + // memory + // this seems very unprudent, but correct, for now + ArrayList>> sortedInput = Lists.newArrayList(); + for (WindowedValue> inputValue: elements) { + for (WindowedValue> exploded: inputValue.explodeWindows()) { + sortedInput.add(exploded); + } + } + Collections.sort(sortedInput, new Comparator>>() { + @Override + public int compare( + WindowedValue> o1, + WindowedValue> o2) { + return Iterables.getOnlyElement(o1.getWindows()).maxTimestamp() + .compareTo(Iterables.getOnlyElement(o2.getWindows()).maxTimestamp()); + } + }); + + // iterate over the elements that are sorted by window timestamp + // + final Iterator>> iterator = sortedInput.iterator(); - final Iterator> iterator = elements.iterator(); // create accumulator using the first elements key - KV first = iterator.next(); - K key = first.getKey(); - VI value = first.getValue(); - VA accumulator = keyedCombineFn.createAccumulator(key); - accumulator = keyedCombineFn.addInput(key, accumulator, value); - - while(iterator.hasNext()) { - value = iterator.next().getValue(); - accumulator = keyedCombineFn.addInput(key, accumulator, value); + WindowedValue> currentValue = iterator.next(); + K key = currentValue.getValue().getKey(); + BoundedWindow currentWindow = Iterables.getFirst(currentValue.getWindows(), null); + InputT firstValue = currentValue.getValue().getValue(); + processContext = processContext.forWindowedValue(currentValue); + AccumT accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, 
accumulator, firstValue, processContext); + + // we use this to keep track of the timestamps assigned by the OutputTimeFn + Instant windowTimestamp = + outputTimeFn.assignOutputTime(currentValue.getTimestamp(), currentWindow); + + while (iterator.hasNext()) { + WindowedValue> nextValue = iterator.next(); + BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows()); + + if (nextWindow.equals(currentWindow)) { + // continue accumulating + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + + windowTimestamp = outputTimeFn.combine( + windowTimestamp, + outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow)); + + } else { + // emit the value that we currently have + out.collect( + WindowedValue.of( + KV.of(key, accumulator), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); + + currentWindow = nextWindow; + InputT value = nextValue.getValue().getValue(); + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.createAccumulator(key, processContext); + accumulator = combineFnRunner.addInput(key, accumulator, value, processContext); + windowTimestamp = outputTimeFn.assignOutputTime(nextValue.getTimestamp(), currentWindow); + } } - out.collect(KV.of(key, accumulator)); + // emit the final accumulator + out.collect( + WindowedValue.of( + KV.of(key, accumulator), + windowTimestamp, + currentWindow, + PaneInfo.NO_FIRING)); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkProcessContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkProcessContext.java new file mode 100644 index 000000000000..0f1885ca5192 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkProcessContext.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
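
As the Javadoc above spells out, Combine.PerKey is executed in two steps: FlinkPartialReduceFunction does the local pre-combine before the shuffle (createAccumulator/addInput), and FlinkReduceFunction does the final combination afterwards (mergeAccumulators/extractOutput). Stripped of keys, windows and side inputs, the contract being exercised is the plain CombineFn lifecycle; a minimal sketch with an ad-hoc summing CombineFn (illustration only; the patch itself goes through PerKeyCombineFnRunner):

    import org.apache.beam.sdk.transforms.Combine;

    import java.util.Arrays;

    class TwoPhaseCombineSketch {
      public static void main(String[] args) {
        Combine.CombineFn<Integer, Integer, Integer> sum =
            new Combine.CombineFn<Integer, Integer, Integer>() {
              @Override public Integer createAccumulator() { return 0; }
              @Override public Integer addInput(Integer accumulator, Integer input) {
                return accumulator + input;
              }
              @Override public Integer mergeAccumulators(Iterable<Integer> accumulators) {
                int total = 0;
                for (int accumulator : accumulators) {
                  total += accumulator;
                }
                return total;
              }
              @Override public Integer extractOutput(Integer accumulator) { return accumulator; }
            };

        // "partial reduce", before the shuffle: each worker folds its local values per key and window
        Integer partialA = sum.addInput(sum.addInput(sum.createAccumulator(), 1), 2);  // 3
        Integer partialB = sum.addInput(sum.createAccumulator(), 4);                   // 4

        // "reduce", after the shuffle: merge the shipped accumulators and extract the result
        Integer result = sum.extractOutput(sum.mergeAccumulators(Arrays.asList(partialA, partialB)));
        System.out.println(result);  // 7
      }
    }
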
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.TimerInternals; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingInternals; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; + +/** + * {@link org.apache.beam.sdk.transforms.DoFn.ProcessContext} for our Flink Wrappers. + */ +class FlinkProcessContext + extends DoFn.ProcessContext { + + private final PipelineOptions pipelineOptions; + private final RuntimeContext runtimeContext; + private Collector> collector; + private final boolean requiresWindowAccess; + + protected WindowedValue windowedValue; + + protected WindowingStrategy windowingStrategy; + + private final Map, WindowingStrategy> sideInputs; + + FlinkProcessContext( + PipelineOptions pipelineOptions, + RuntimeContext runtimeContext, + DoFn doFn, + WindowingStrategy windowingStrategy, + Collector> collector, + Map, WindowingStrategy> sideInputs) { + doFn.super(); + Preconditions.checkNotNull(pipelineOptions); + Preconditions.checkNotNull(runtimeContext); + Preconditions.checkNotNull(doFn); + Preconditions.checkNotNull(collector); + + this.pipelineOptions = pipelineOptions; + this.runtimeContext = runtimeContext; + this.collector = collector; + this.requiresWindowAccess = doFn instanceof DoFn.RequiresWindowAccess; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + + super.setupDelegateAggregators(); + } + + FlinkProcessContext( + PipelineOptions pipelineOptions, + RuntimeContext runtimeContext, + DoFn doFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs) { + doFn.super(); + Preconditions.checkNotNull(pipelineOptions); + Preconditions.checkNotNull(runtimeContext); + Preconditions.checkNotNull(doFn); + + this.pipelineOptions = pipelineOptions; + this.runtimeContext = runtimeContext; + this.collector = null; + this.requiresWindowAccess = doFn instanceof DoFn.RequiresWindowAccess; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + + super.setupDelegateAggregators(); + } + + public FlinkProcessContext forOutput( + Collector> collector) { + this.collector = collector; + + // for now, returns ourselves, to be easy on the GC + return this; + } + + + + public FlinkProcessContext forWindowedValue( + WindowedValue windowedValue) { + this.windowedValue = windowedValue; + + // for now, returns ourselves, to be easy on the GC + return this; + } + + @Override + public InputT element() { + return this.windowedValue.getValue(); + } + + + @Override + public 
Instant timestamp() { + return windowedValue.getTimestamp(); + } + + @Override + public BoundedWindow window() { + if (!requiresWindowAccess) { + throw new UnsupportedOperationException( + "window() is only available in the context of a DoFn marked as RequiresWindow."); + } + return Iterables.getOnlyElement(windowedValue.getWindows()); + } + + @Override + public PaneInfo pane() { + return windowedValue.getPane(); + } + + @Override + public WindowingInternals windowingInternals() { + + return new WindowingInternals() { + + @Override + public StateInternals stateInternals() { + throw new UnsupportedOperationException(); + } + + @Override + public void outputWindowedValue( + OutputT value, + Instant timestamp, + Collection windows, + PaneInfo pane) { + collector.collect(WindowedValue.of(value, timestamp, windows, pane)); + outputWithTimestampAndWindow(value, timestamp, windows, pane); + } + + @Override + public TimerInternals timerInternals() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection windows() { + return windowedValue.getWindows(); + } + + @Override + public PaneInfo pane() { + return windowedValue.getPane(); + } + + @Override + public void writePCollectionViewData(TupleTag tag, + Iterable> data, Coder elemCoder) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ViewT sideInput( + PCollectionView view, + BoundedWindow mainInputWindow) { + + Preconditions.checkNotNull(view, "View passed to sideInput cannot be null"); + Preconditions.checkNotNull( + sideInputs.get(view), + "Side input for " + view + " not available."); + + // get the side input strategy for mapping the window + WindowingStrategy windowingStrategy = sideInputs.get(view); + + BoundedWindow sideInputWindow = + windowingStrategy.getWindowFn().getSideInputWindow(mainInputWindow); + + Map sideInputs = + runtimeContext.getBroadcastVariableWithInitializer( + view.getTagInternal().getId(), new SideInputInitializer<>(view)); + return sideInputs.get(sideInputWindow); + } + }; + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public ViewT sideInput(PCollectionView view) { + Preconditions.checkNotNull(view, "View passed to sideInput cannot be null"); + Preconditions.checkNotNull(sideInputs.get(view), "Side input for " + view + " not available."); + Iterator windowIter = windowedValue.getWindows().iterator(); + BoundedWindow window; + if (!windowIter.hasNext()) { + throw new IllegalStateException( + "sideInput called when main input element is not in any windows"); + } else { + window = windowIter.next(); + if (windowIter.hasNext()) { + throw new IllegalStateException( + "sideInput called when main input element is in multiple windows"); + } + } + + // get the side input strategy for mapping the window + WindowingStrategy windowingStrategy = sideInputs.get(view); + + BoundedWindow sideInputWindow = + windowingStrategy.getWindowFn().getSideInputWindow(window); + + Map sideInputs = + runtimeContext.getBroadcastVariableWithInitializer( + view.getTagInternal().getId(), new SideInputInitializer<>(view)); + ViewT result = sideInputs.get(sideInputWindow); + if (result == null) { + result = view.fromIterableInternal(Collections.>emptyList()); + } + return result; + } + + @Override + public void output(OutputT value) { + if (windowedValue != null) { + outputWithTimestamp(value, windowedValue.getTimestamp()); + } else { + outputWithTimestamp(value, null); + } + } + + @Override + public void 
outputWithTimestamp(OutputT value, Instant timestamp) { + if (windowedValue == null) { + // we are in startBundle() or finishBundle() + + try { + Collection windows = windowingStrategy.getWindowFn().assignWindows( + new FlinkNoElementAssignContext( + windowingStrategy.getWindowFn(), + value, + timestamp)); + + collector.collect( + WindowedValue.of( + value, + timestamp != null ? timestamp : new Instant(Long.MIN_VALUE), + windows, + PaneInfo.NO_FIRING)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + collector.collect( + WindowedValue.of( + value, + timestamp, + windowedValue.getWindows(), + windowedValue.getPane())); + } + } + + protected void outputWithTimestampAndWindow( + OutputT value, + Instant timestamp, + Collection windows, + PaneInfo pane) { + collector.collect( + WindowedValue.of( + value, timestamp, windows, pane)); + } + + @Override + public void sideOutput(TupleTag tag, T output) { + throw new UnsupportedOperationException(); + } + + @Override + public void sideOutputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected Aggregator + createAggregatorInternal(String name, Combine.CombineFn combiner) { + SerializableFnAggregatorWrapper wrapper = + new SerializableFnAggregatorWrapper<>(combiner); + runtimeContext.addAccumulator(name, wrapper); + return wrapper; + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java index 43e458fc3720..9cbc6b914765 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java @@ -17,43 +17,179 @@ */ package org.apache.beam.runners.flink.translation.functions; -import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.PerKeyCombineFnRunner; +import org.apache.beam.sdk.util.PerKeyCombineFnRunners; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; -import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.util.Collector; +import org.joda.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; +import java.util.List; +import java.util.Map; /** - * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a - * {@link org.apache.beam.sdk.transforms.Combine.PerKey} operation. 
This reads the input - * {@link org.apache.beam.sdk.values.KV} elements, extracts the key and merges the - * accumulators resulting from the PartialReduce which produced the input VA. + * This is the second part for executing a {@link org.apache.beam.sdk.transforms.Combine.PerKey} + * on Flink, the first part is {@link FlinkPartialReduceFunction}. This function performs the final + * combination of the pre-combined values after a shuffle. + * + *
    The input to {@link #reduce(Iterable, Collector)} are elements of the same key but + * for different windows. We have to ensure that we only combine elements of matching + * windows. */ -public class FlinkReduceFunction implements GroupReduceFunction, KV> { +public class FlinkReduceFunction + extends RichGroupReduceFunction>, WindowedValue>> { + + protected final CombineFnBase.PerKeyCombineFn combineFn; + + protected final DoFn, KV> doFn; + + protected final WindowingStrategy windowingStrategy; + + protected final Map, WindowingStrategy> sideInputs; + + protected final SerializedPipelineOptions serializedOptions; - private final Combine.KeyedCombineFn keyedCombineFn; + public FlinkReduceFunction( + CombineFnBase.PerKeyCombineFn keyedCombineFn, + WindowingStrategy windowingStrategy, + Map, WindowingStrategy> sideInputs, + PipelineOptions pipelineOptions) { - public FlinkReduceFunction(Combine.KeyedCombineFn keyedCombineFn) { - this.keyedCombineFn = keyedCombineFn; + this.combineFn = keyedCombineFn; + + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + + this.serializedOptions = new SerializedPipelineOptions(pipelineOptions); + + // dummy DoFn because we need one for ProcessContext + this.doFn = new DoFn, KV>() { + @Override + public void processElement(ProcessContext c) throws Exception { + + } + }; } @Override - public void reduce(Iterable> values, Collector> out) throws Exception { - Iterator> it = values.iterator(); + public void reduce( + Iterable>> elements, + Collector>> out) throws Exception { + + FlinkProcessContext, KV> processContext = + new FlinkProcessContext<>( + serializedOptions.getPipelineOptions(), + getRuntimeContext(), + doFn, + windowingStrategy, + out, + sideInputs); + + PerKeyCombineFnRunner combineFnRunner = + PerKeyCombineFnRunners.create(combineFn); - KV current = it.next(); - K k = current.getKey(); - VA accumulator = current.getValue(); + @SuppressWarnings("unchecked") + OutputTimeFn outputTimeFn = + (OutputTimeFn) windowingStrategy.getOutputTimeFn(); - while (it.hasNext()) { - current = it.next(); - keyedCombineFn.mergeAccumulators(k, ImmutableList.of(accumulator, current.getValue()) ); + + // get all elements so that we can sort them, has to fit into + // memory + // this seems very unprudent, but correct, for now + ArrayList>> sortedInput = Lists.newArrayList(); + for (WindowedValue> inputValue: elements) { + for (WindowedValue> exploded: inputValue.explodeWindows()) { + sortedInput.add(exploded); + } + } + Collections.sort(sortedInput, new Comparator>>() { + @Override + public int compare( + WindowedValue> o1, + WindowedValue> o2) { + return Iterables.getOnlyElement(o1.getWindows()).maxTimestamp() + .compareTo(Iterables.getOnlyElement(o2.getWindows()).maxTimestamp()); + } + }); + + // iterate over the elements that are sorted by window timestamp + // + final Iterator>> iterator = sortedInput.iterator(); + + // get the first accumulator + WindowedValue> currentValue = iterator.next(); + K key = currentValue.getValue().getKey(); + BoundedWindow currentWindow = Iterables.getFirst(currentValue.getWindows(), null); + AccumT accumulator = currentValue.getValue().getValue(); + + // we use this to keep track of the timestamps assigned by the OutputTimeFn, + // in FlinkPartialReduceFunction we already merge the timestamps assigned + // to individual elements, here we just merge them + List windowTimestamps = new ArrayList<>(); + windowTimestamps.add(currentValue.getTimestamp()); + + while (iterator.hasNext()) { + WindowedValue> nextValue 
= iterator.next(); + BoundedWindow nextWindow = Iterables.getOnlyElement(nextValue.getWindows()); + + if (nextWindow.equals(currentWindow)) { + // continue accumulating + processContext = processContext.forWindowedValue(nextValue); + accumulator = combineFnRunner.mergeAccumulators( + key, ImmutableList.of(accumulator, nextValue.getValue().getValue()), processContext); + + windowTimestamps.add(nextValue.getTimestamp()); + } else { + // emit the value that we currently have + processContext = processContext.forWindowedValue(currentValue); + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + outputTimeFn.merge(currentWindow, windowTimestamps), + currentWindow, + PaneInfo.NO_FIRING)); + + windowTimestamps.clear(); + + currentWindow = nextWindow; + accumulator = nextValue.getValue().getValue(); + windowTimestamps.add(nextValue.getTimestamp()); + } + + // we have to keep track so that we can set the context to the right + // windowed value when windows change in the iterable + currentValue = nextValue; } - out.collect(KV.of(k, keyedCombineFn.extractOutput(k, accumulator))); + // if at the end of the iteration we have a change in windows + // the ProcessContext will not have been updated + processContext = processContext.forWindowedValue(currentValue); + + // emit the final accumulator + out.collect( + WindowedValue.of( + KV.of(key, combineFnRunner.extractOutput(key, accumulator, processContext)), + outputTimeFn.merge(currentWindow, windowTimestamps), + currentWindow, + PaneInfo.NO_FIRING)); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/SideInputInitializer.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/SideInputInitializer.java new file mode 100644 index 000000000000..451b31b12c5e --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/SideInputInitializer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollectionView; + +import org.apache.flink.api.common.functions.BroadcastVariableInitializer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * {@link BroadcastVariableInitializer} that initializes the broadcast input as a {@code Map} + * from window to side input. 
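The reduce(...) body added above boils down to: explode each input into one value per window, sort by window end timestamp, then walk the sorted list, merging accumulators while the window stays the same and emitting one result whenever the window changes (the real code also merges element timestamps through the OutputTimeFn to pick the output timestamp, which is omitted here). A minimal standalone illustration of that loop in plain Java; WindowedAccum and the merge callback are hypothetical stand-ins, not the Beam/Flink API:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.BinaryOperator;

public class PerWindowMergeSketch {

  // Hypothetical stand-in for one exploded, pre-combined value together with its window.
  static final class WindowedAccum {
    final long windowEnd;   // stands in for BoundedWindow.maxTimestamp()
    final int accumulator;  // stands in for the pre-combined accumulator

    WindowedAccum(long windowEnd, int accumulator) {
      this.windowEnd = windowEnd;
      this.accumulator = accumulator;
    }
  }

  // Sort by window end, keep merging while the window is unchanged, emit once per window.
  static List<String> mergePerWindow(List<WindowedAccum> input, BinaryOperator<Integer> merge) {
    List<WindowedAccum> sorted = new ArrayList<>(input);
    sorted.sort(Comparator.comparingLong(w -> w.windowEnd));

    List<String> output = new ArrayList<>();
    Long currentWindow = null;
    int accumulator = 0;
    for (WindowedAccum value : sorted) {
      if (currentWindow != null && value.windowEnd == currentWindow) {
        // same window: continue accumulating
        accumulator = merge.apply(accumulator, value.accumulator);
      } else {
        if (currentWindow != null) {
          // window changed: emit the value we currently have
          output.add("window=" + currentWindow + " result=" + accumulator);
        }
        currentWindow = value.windowEnd;
        accumulator = value.accumulator;
      }
    }
    if (currentWindow != null) {
      // emit the final accumulator
      output.add("window=" + currentWindow + " result=" + accumulator);
    }
    return output;
  }

  public static void main(String[] args) {
    List<WindowedAccum> input = Arrays.asList(
        new WindowedAccum(10, 3), new WindowedAccum(20, 5), new WindowedAccum(10, 4));
    System.out.println(mergePerWindow(input, Integer::sum)); // [window=10 result=7, window=20 result=5]
  }
}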
+ */ +public class SideInputInitializer + implements BroadcastVariableInitializer, Map> { + + PCollectionView view; + + public SideInputInitializer(PCollectionView view) { + this.view = view; + } + + @Override + public Map initializeBroadcastVariable( + Iterable> inputValues) { + + // first partition into windows + Map>> partitionedElements = new HashMap<>(); + for (WindowedValue value: inputValues) { + for (BoundedWindow window: value.getWindows()) { + List> windowedValues = partitionedElements.get(window); + if (windowedValues == null) { + windowedValues = new ArrayList<>(); + partitionedElements.put(window, windowedValues); + } + windowedValues.add(value); + } + } + + Map resultMap = new HashMap<>(); + + for (Map.Entry>> elements: + partitionedElements.entrySet()) { + + @SuppressWarnings("unchecked") + Iterable> elementsIterable = + (List>) (List) elements.getValue(); + + resultMap.put(elements.getKey(), view.fromIterableInternal(elementsIterable)); + } + + return resultMap; + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java deleted file mode 100644 index cc6fd8b70917..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.functions; - - -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StandardCoder; -import org.apache.beam.sdk.transforms.join.RawUnionValue; -import org.apache.beam.sdk.util.PropertyNames; -import org.apache.beam.sdk.util.VarInt; -import org.apache.beam.sdk.util.common.ElementByteSizeObserver; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.List; - -/** - * A UnionCoder encodes RawUnionValues. - * - * This file copied from {@link org.apache.beam.sdk.transforms.join.UnionCoder} - */ -@SuppressWarnings("serial") -public class UnionCoder extends StandardCoder { - // TODO: Think about how to integrate this with a schema object (i.e. - // a tuple of tuple tags). - /** - * Builds a union coder with the given list of element coders. This list - * corresponds to a mapping of union tag to Coder. Union tags start at 0. 
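The SideInputInitializer added above builds its broadcast result in two passes: it first buckets every windowed value under each window it belongs to, then materializes one view value per window via view.fromIterableInternal(...). A rough standalone sketch of that shape in plain Java; the Windowed type and the materialize callback are simplified stand-ins, not the Beam PCollectionView API:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class SideInputMapSketch {

  // Hypothetical stand-in for a WindowedValue: one value, possibly in several windows.
  static final class Windowed<T> {
    final T value;
    final List<String> windows;

    Windowed(T value, List<String> windows) {
      this.value = value;
      this.windows = windows;
    }
  }

  // Groups values by window, then turns each window's values into one materialized view.
  static <T, ViewT> Map<String, ViewT> buildSideInputMap(
      Iterable<Windowed<T>> broadcast, Function<List<T>, ViewT> materialize) {

    // First partition the elements into windows (a value in N windows lands in N buckets).
    Map<String, List<T>> partitioned = new HashMap<>();
    for (Windowed<T> wv : broadcast) {
      for (String window : wv.windows) {
        List<T> bucket = partitioned.get(window);
        if (bucket == null) {
          bucket = new ArrayList<>();
          partitioned.put(window, bucket);
        }
        bucket.add(wv.value);
      }
    }

    // Then materialize one view per window, mirroring view.fromIterableInternal(...).
    Map<String, ViewT> result = new HashMap<>();
    for (Map.Entry<String, List<T>> entry : partitioned.entrySet()) {
      result.put(entry.getKey(), materialize.apply(entry.getValue()));
    }
    return result;
  }

  public static void main(String[] args) {
    List<Windowed<Integer>> broadcast = Arrays.asList(
        new Windowed<>(1, Arrays.asList("w1")),
        new Windowed<>(2, Arrays.asList("w1", "w2")));
    // Materialize each window's contents as a plain list, like an iterable-style view.
    Map<String, List<Integer>> sideInputs =
        buildSideInputMap(broadcast, values -> new ArrayList<>(values));
    System.out.println(sideInputs); // two windows: w1 -> [1, 2], w2 -> [2]
  }
}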
- */ - public static UnionCoder of(List> elementCoders) { - return new UnionCoder(elementCoders); - } - - @JsonCreator - public static UnionCoder jsonOf( - @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) - List> elements) { - return UnionCoder.of(elements); - } - - private int getIndexForEncoding(RawUnionValue union) { - if (union == null) { - throw new IllegalArgumentException("cannot encode a null tagged union"); - } - int index = union.getUnionTag(); - if (index < 0 || index >= elementCoders.size()) { - throw new IllegalArgumentException( - "union value index " + index + " not in range [0.." + - (elementCoders.size() - 1) + "]"); - } - return index; - } - - @SuppressWarnings("unchecked") - @Override - public void encode( - RawUnionValue union, - OutputStream outStream, - Context context) - throws IOException { - int index = getIndexForEncoding(union); - // Write out the union tag. - VarInt.encode(index, outStream); - - // Write out the actual value. - Coder coder = (Coder) elementCoders.get(index); - coder.encode( - union.getValue(), - outStream, - context); - } - - @Override - public RawUnionValue decode(InputStream inStream, Context context) - throws IOException { - int index = VarInt.decodeInt(inStream); - Object value = elementCoders.get(index).decode(inStream, context); - return new RawUnionValue(index, value); - } - - @Override - public List> getCoderArguments() { - return null; - } - - @Override - public List> getComponents() { - return elementCoders; - } - - /** - * Since this coder uses elementCoders.get(index) and coders that are known to run in constant - * time, we defer the return value to that coder. - */ - @Override - public boolean isRegisterByteSizeObserverCheap(RawUnionValue union, Context context) { - int index = getIndexForEncoding(union); - @SuppressWarnings("unchecked") - Coder coder = (Coder) elementCoders.get(index); - return coder.isRegisterByteSizeObserverCheap(union.getValue(), context); - } - - /** - * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. - */ - @Override - public void registerByteSizeObserver( - RawUnionValue union, ElementByteSizeObserver observer, Context context) - throws Exception { - int index = getIndexForEncoding(union); - // Write out the union tag. - observer.update(VarInt.getLength(index)); - // Write out the actual value. 
- @SuppressWarnings("unchecked") - Coder coder = (Coder) elementCoders.get(index); - coder.registerByteSizeObserver(union.getValue(), observer, context); - } - - ///////////////////////////////////////////////////////////////////////////// - - private final List> elementCoders; - - private UnionCoder(List> elementCoders) { - this.elementCoders = elementCoders; - } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - verifyDeterministic( - "UnionCoder is only deterministic if all element coders are", - elementCoders); - } -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java index 895ecef1b92e..4434cf8726e1 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java @@ -18,7 +18,8 @@ package org.apache.beam.runners.flink.translation.types; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.util.WindowedValue; import com.google.common.base.Preconditions; @@ -71,9 +72,6 @@ public boolean isKeyType() { @Override @SuppressWarnings("unchecked") public TypeSerializer createSerializer(ExecutionConfig config) { - if (coder instanceof VoidCoder) { - return (TypeSerializer) new VoidCoderTypeSerializer(); - } return new CoderTypeSerializer<>(coder); } @@ -84,8 +82,12 @@ public int getTotalFields() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } CoderTypeInformation that = (CoderTypeInformation) o; @@ -113,6 +115,11 @@ public String toString() { @Override public TypeComparator createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) { - return new CoderComparator<>(coder); + WindowedValue.WindowedValueCoder windowCoder = (WindowedValue.WindowedValueCoder) coder; + if (windowCoder.getValueCoder() instanceof KvCoder) { + return new KvCoderComperator(windowCoder); + } else { + return new CoderComparator<>(coder); + } } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index c6f3921971a6..097316b242fd 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -33,7 +33,7 @@ /** * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for - * Dataflow {@link org.apache.beam.sdk.coders.Coder}s + * Dataflow {@link org.apache.beam.sdk.coders.Coder Coders}. 
*/ public class CoderTypeSerializer extends TypeSerializer { @@ -128,14 +128,20 @@ public T deserialize(T t, DataInputView dataInputView) throws IOException { } @Override - public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { + public void copy( + DataInputView dataInputView, + DataOutputView dataOutputView) throws IOException { serialize(deserialize(dataInputView), dataOutputView); } @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } CoderTypeSerializer that = (CoderTypeSerializer) o; return coder.equals(that.coder); diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java index 6f0c651406a2..79b127d1062c 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java @@ -20,6 +20,8 @@ import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -31,14 +33,13 @@ import java.io.ObjectInputStream; /** - * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for - * {@link org.apache.beam.sdk.coders.KvCoder}. We have a special comparator + * Flink {@link TypeComparator} for {@link KvCoder}. We have a special comparator * for {@link KV} that always compares on the key only. 
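The comparator below orders records by their key alone: each key is run through the key coder and the resulting byte arrays are compared, first by length and then byte by byte. A compact standalone sketch of that key-only, byte-wise comparison in plain Java; encodeKey here is a hypothetical stand-in for Coder.encode, not the real API:

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

public class EncodedKeyComparatorSketch {

  // Compares two already-encoded keys: length difference first, then the first differing byte.
  static int compareEncoded(byte[] a, byte[] b) {
    if (a.length != b.length) {
      return a.length - b.length;
    }
    for (int i = 0; i < a.length; i++) {
      if (a[i] != b[i]) {
        return a[i] - b[i];
      }
    }
    return 0;
  }

  // Builds a record comparator that looks only at the encoded key, never at the value.
  static <RecordT, K> Comparator<RecordT> keyOnly(
      Function<RecordT, K> keyOf, Function<K, byte[]> encodeKey) {
    return (first, second) ->
        compareEncoded(encodeKey.apply(keyOf.apply(first)), encodeKey.apply(keyOf.apply(second)));
  }

  public static void main(String[] args) {
    // Hypothetical "coder": encode a String key as UTF-8 bytes.
    Function<String, byte[]> encodeKey = k -> k.getBytes(StandardCharsets.UTF_8);

    // Records are "key=value" strings; only the key part takes part in the comparison.
    Function<String, String> keyOf = record -> record.split("=", 2)[0];

    List<String> records = Arrays.asList("b=2", "a=9", "a=1");
    records.sort(keyOnly(keyOf, encodeKey));
    System.out.println(records); // [a=9, a=1, b=2] -- ordered by key bytes only
  }
}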
*/ -public class KvCoderComperator extends TypeComparator> { +public class KvCoderComperator extends TypeComparator>> { - private KvCoder coder; - private Coder keyCoder; + private final WindowedValue.WindowedValueCoder> coder; + private final Coder keyCoder; // We use these for internal encoding/decoding for creating copies and comparing // serialized forms using a Coder @@ -52,9 +53,10 @@ public class KvCoderComperator extends TypeComparator> { // For deserializing the key private transient DataInputViewWrapper inputWrapper; - public KvCoderComperator(KvCoder coder) { + public KvCoderComperator(WindowedValue.WindowedValueCoder> coder) { this.coder = coder; - this.keyCoder = coder.getKeyCoder(); + KvCoder kvCoder = (KvCoder) coder.getValueCoder(); + this.keyCoder = kvCoder.getKeyCoder(); buffer1 = new InspectableByteArrayOutputStream(); buffer2 = new InspectableByteArrayOutputStream(); @@ -74,8 +76,8 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE } @Override - public int hash(KV record) { - K key = record.getKey(); + public int hash(WindowedValue> record) { + K key = record.getValue().getKey(); if (key != null) { return key.hashCode(); } else { @@ -84,27 +86,27 @@ public int hash(KV record) { } @Override - public void setReference(KV toCompare) { + public void setReference(WindowedValue> toCompare) { referenceBuffer.reset(); try { - keyCoder.encode(toCompare.getKey(), referenceBuffer, Coder.Context.OUTER); + keyCoder.encode(toCompare.getValue().getKey(), referenceBuffer, Coder.Context.OUTER); } catch (IOException e) { throw new RuntimeException("Could not set reference " + toCompare + ": " + e); } } @Override - public boolean equalToReference(KV candidate) { + public boolean equalToReference(WindowedValue> candidate) { try { buffer2.reset(); - keyCoder.encode(candidate.getKey(), buffer2, Coder.Context.OUTER); + keyCoder.encode(candidate.getValue().getKey(), buffer2, Coder.Context.OUTER); byte[] arr = referenceBuffer.getBuffer(); byte[] arrOther = buffer2.getBuffer(); if (referenceBuffer.size() != buffer2.size()) { return false; } int len = buffer2.size(); - for(int i = 0; i < len; i++ ) { + for (int i = 0; i < len; i++) { if (arr[i] != arrOther[i]) { return false; } @@ -116,8 +118,9 @@ public boolean equalToReference(KV candidate) { } @Override - public int compareToReference(TypeComparator> other) { - InspectableByteArrayOutputStream otherReferenceBuffer = ((KvCoderComperator) other).referenceBuffer; + public int compareToReference(TypeComparator>> other) { + InspectableByteArrayOutputStream otherReferenceBuffer = + ((KvCoderComperator) other).referenceBuffer; byte[] arr = referenceBuffer.getBuffer(); byte[] arrOther = otherReferenceBuffer.getBuffer(); @@ -135,19 +138,19 @@ public int compareToReference(TypeComparator> other) { @Override - public int compare(KV first, KV second) { + public int compare(WindowedValue> first, WindowedValue> second) { try { buffer1.reset(); buffer2.reset(); - keyCoder.encode(first.getKey(), buffer1, Coder.Context.OUTER); - keyCoder.encode(second.getKey(), buffer2, Coder.Context.OUTER); + keyCoder.encode(first.getValue().getKey(), buffer1, Coder.Context.OUTER); + keyCoder.encode(second.getValue().getKey(), buffer2, Coder.Context.OUTER); byte[] arr = buffer1.getBuffer(); byte[] arrOther = buffer2.getBuffer(); if (buffer1.size() != buffer2.size()) { return buffer1.size() - buffer2.size(); } int len = buffer1.size(); - for(int i = 0; i < len; i++ ) { + for (int i = 0; i < len; i++) { if (arr[i] != arrOther[i]) { return arr[i] 
- arrOther[i]; } @@ -159,38 +162,19 @@ public int compare(KV first, KV second) { } @Override - public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { - + public int compareSerialized( + DataInputView firstSource, + DataInputView secondSource) throws IOException { inputWrapper.setInputView(firstSource); - K firstKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); + WindowedValue> first = coder.decode(inputWrapper, Coder.Context.NESTED); inputWrapper.setInputView(secondSource); - K secondKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); - - try { - buffer1.reset(); - buffer2.reset(); - keyCoder.encode(firstKey, buffer1, Coder.Context.OUTER); - keyCoder.encode(secondKey, buffer2, Coder.Context.OUTER); - byte[] arr = buffer1.getBuffer(); - byte[] arrOther = buffer2.getBuffer(); - if (buffer1.size() != buffer2.size()) { - return buffer1.size() - buffer2.size(); - } - int len = buffer1.size(); - for(int i = 0; i < len; i++ ) { - if (arr[i] != arrOther[i]) { - return arr[i] - arrOther[i]; - } - } - return 0; - } catch (IOException e) { - throw new RuntimeException("Could not compare reference.", e); - } + WindowedValue> second = coder.decode(inputWrapper, Coder.Context.NESTED); + return compare(first, second); } @Override public boolean supportsNormalizedKey() { - return true; + return false; } @Override @@ -209,12 +193,18 @@ public boolean isNormalizedKeyPrefixOnly(int keyBytes) { } @Override - public void putNormalizedKey(KV record, MemorySegment target, int offset, int numBytes) { + public void putNormalizedKey( + WindowedValue> record, + MemorySegment target, + int offset, + int numBytes) { + buffer1.reset(); try { - keyCoder.encode(record.getKey(), buffer1, Coder.Context.NESTED); + keyCoder.encode(record.getValue().getKey(), buffer1, Coder.Context.NESTED); } catch (IOException e) { - throw new RuntimeException("Could not serializer " + record + " using coder " + coder + ": " + e); + throw new RuntimeException( + "Could not serializer " + record + " using coder " + coder + ": " + e); } final byte[] data = buffer1.getBuffer(); final int limit = offset + numBytes; @@ -231,12 +221,16 @@ public void putNormalizedKey(KV record, MemorySegment target, int offset, } @Override - public void writeWithKeyNormalization(KV record, DataOutputView target) throws IOException { + public void writeWithKeyNormalization( + WindowedValue> record, + DataOutputView target) throws IOException { throw new UnsupportedOperationException(); } @Override - public KV readWithKeyDenormalization(KV reuse, DataInputView source) throws IOException { + public WindowedValue> readWithKeyDenormalization( + WindowedValue> reuse, + DataInputView source) throws IOException { throw new UnsupportedOperationException(); } @@ -246,14 +240,14 @@ public boolean invertNormalizedKey() { } @Override - public TypeComparator> duplicate() { + public TypeComparator>> duplicate() { return new KvCoderComperator<>(coder); } @Override public int extractKeys(Object record, Object[] target, int index) { - KV kv = (KV) record; - K k = kv.getKey(); + WindowedValue> kv = (WindowedValue>) record; + K k = kv.getValue().getKey(); target[index] = k; return 1; } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java index 74f3821dfb2f..ba53f640bb81 100644 --- 
a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.flink.translation.types; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import com.google.common.base.Preconditions; @@ -31,27 +32,32 @@ import java.util.List; /** - * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for - * Dataflow {@link org.apache.beam.sdk.coders.KvCoder}. + * Flink {@link TypeInformation} for {@link KvCoder}. This creates special comparator + * for {@link KV} that always compares on the key only. */ -public class KvCoderTypeInformation extends CompositeType> { +public class KvCoderTypeInformation extends CompositeType>> { - private KvCoder coder; + private final WindowedValue.WindowedValueCoder> coder; +// private KvCoder coder; // We don't have the Class, so we have to pass null here. What a shame... - private static Object DUMMY = new Object(); + private static Object dummy = new Object(); @SuppressWarnings("unchecked") - public KvCoderTypeInformation(KvCoder coder) { - super(((Class>) DUMMY.getClass())); + public KvCoderTypeInformation(WindowedValue.WindowedValueCoder> coder) { + super((Class) dummy.getClass()); this.coder = coder; Preconditions.checkNotNull(coder); } @Override @SuppressWarnings("unchecked") - public TypeComparator> createComparator(int[] logicalKeyFields, boolean[] orders, int logicalFieldOffset, ExecutionConfig config) { - return new KvCoderComperator((KvCoder) coder); + public TypeComparator>> createComparator( + int[] logicalKeyFields, + boolean[] orders, + int logicalFieldOffset, + ExecutionConfig config) { + return new KvCoderComperator(coder); } @Override @@ -71,7 +77,7 @@ public int getArity() { @Override @SuppressWarnings("unchecked") - public Class> getTypeClass() { + public Class>> getTypeClass() { return privateGetTypeClass(); } @@ -87,7 +93,7 @@ public boolean isKeyType() { @Override @SuppressWarnings("unchecked") - public TypeSerializer> createSerializer(ExecutionConfig config) { + public TypeSerializer>> createSerializer(ExecutionConfig config) { return new CoderTypeSerializer<>(coder); } @@ -98,8 +104,12 @@ public int getTotalFields() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } KvCoderTypeInformation that = (KvCoderTypeInformation) o; @@ -122,10 +132,11 @@ public String toString() { @Override @SuppressWarnings("unchecked") public TypeInformation getTypeAt(int pos) { + KvCoder kvCoder = (KvCoder) coder.getValueCoder(); if (pos == 0) { - return (TypeInformation) new CoderTypeInformation<>(coder.getKeyCoder()); + return (TypeInformation) new CoderTypeInformation<>(kvCoder.getKeyCoder()); } else if (pos == 1) { - return (TypeInformation) new CoderTypeInformation<>(coder.getValueCoder()); + return (TypeInformation) new CoderTypeInformation<>(kvCoder.getValueCoder()); } else { throw new RuntimeException("Invalid field position " + pos); } @@ -134,11 +145,12 @@ public TypeInformation getTypeAt(int pos) { @Override @SuppressWarnings("unchecked") public TypeInformation getTypeAt(String fieldExpression) { + KvCoder kvCoder = (KvCoder) coder.getValueCoder(); switch 
(fieldExpression) { case "key": - return (TypeInformation) new CoderTypeInformation<>(coder.getKeyCoder()); + return (TypeInformation) new CoderTypeInformation<>(kvCoder.getKeyCoder()); case "value": - return (TypeInformation) new CoderTypeInformation<>(coder.getValueCoder()); + return (TypeInformation) new CoderTypeInformation<>(kvCoder.getValueCoder()); default: throw new UnsupportedOperationException("Only KvCoder has fields."); } @@ -162,17 +174,24 @@ public int getFieldIndex(String fieldName) { } @Override - public void getFlatFields(String fieldExpression, int offset, List result) { - CoderTypeInformation keyTypeInfo = new CoderTypeInformation<>(coder.getKeyCoder()); + public void getFlatFields( + String fieldExpression, + int offset, + List result) { + KvCoder kvCoder = (KvCoder) coder.getValueCoder(); + + CoderTypeInformation keyTypeInfo = + new CoderTypeInformation<>(kvCoder.getKeyCoder()); result.add(new FlatFieldDescriptor(0, keyTypeInfo)); } @Override - protected TypeComparatorBuilder> createTypeComparatorBuilder() { + protected TypeComparatorBuilder>> createTypeComparatorBuilder() { return new KvCoderTypeComparatorBuilder(); } - private class KvCoderTypeComparatorBuilder implements TypeComparatorBuilder> { + private class KvCoderTypeComparatorBuilder + implements TypeComparatorBuilder>> { @Override public void initializeTypeComparatorBuilder(int size) {} @@ -181,7 +200,7 @@ public void initializeTypeComparatorBuilder(int size) {} public void addComparatorField(int fieldId, TypeComparator comparator) {} @Override - public TypeComparator> createTypeComparator(ExecutionConfig config) { + public TypeComparator>> createTypeComparator(ExecutionConfig config) { return new KvCoderComperator<>(coder); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java deleted file mode 100644 index 7b48208845fd..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.types; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; - -import java.io.IOException; - -/** - * Special Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for - * {@link org.apache.beam.sdk.coders.VoidCoder}. We need this because Flink does not - * allow returning {@code null} from an input reader. 
We return a {@link VoidValue} instead - * that behaves like a {@code null}, hopefully. - */ -public class VoidCoderTypeSerializer extends TypeSerializer { - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public VoidCoderTypeSerializer duplicate() { - return this; - } - - @Override - public VoidValue createInstance() { - return VoidValue.INSTANCE; - } - - @Override - public VoidValue copy(VoidValue from) { - return from; - } - - @Override - public VoidValue copy(VoidValue from, VoidValue reuse) { - return from; - } - - @Override - public int getLength() { - return 0; - } - - @Override - public void serialize(VoidValue record, DataOutputView target) throws IOException { - target.writeByte(1); - } - - @Override - public VoidValue deserialize(DataInputView source) throws IOException { - source.readByte(); - return VoidValue.INSTANCE; - } - - @Override - public VoidValue deserialize(VoidValue reuse, DataInputView source) throws IOException { - return deserialize(source); - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - source.readByte(); - target.writeByte(1); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof VoidCoderTypeSerializer) { - VoidCoderTypeSerializer other = (VoidCoderTypeSerializer) obj; - return other.canEqual(this); - } else { - return false; - } - } - - @Override - public boolean canEqual(Object obj) { - return obj instanceof VoidCoderTypeSerializer; - } - - @Override - public int hashCode() { - return 0; - } - - public static class VoidValue { - private VoidValue() {} - - public static VoidValue INSTANCE = new VoidValue(); - } - -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java deleted file mode 100644 index e5567d3ea3b2..000000000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.wrappers; - -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.Combine; - -import com.google.common.collect.Lists; - -import org.apache.flink.api.common.accumulators.Accumulator; - -import java.io.Serializable; - -/** - * Wrapper that wraps a {@link org.apache.beam.sdk.transforms.Combine.CombineFn} - * in a Flink {@link org.apache.flink.api.common.accumulators.Accumulator} for using - * the combine function as an aggregator in a {@link org.apache.beam.sdk.transforms.ParDo} - * operation. 
- */ -public class CombineFnAggregatorWrapper implements Aggregator, Accumulator { - - private AA aa; - private Combine.CombineFn combiner; - - public CombineFnAggregatorWrapper() { - } - - public CombineFnAggregatorWrapper(Combine.CombineFn combiner) { - this.combiner = combiner; - this.aa = combiner.createAccumulator(); - } - - @Override - public void add(AI value) { - combiner.addInput(aa, value); - } - - @Override - public Serializable getLocalValue() { - return (Serializable) combiner.extractOutput(aa); - } - - @Override - public void resetLocal() { - aa = combiner.createAccumulator(); - } - - @Override - @SuppressWarnings("unchecked") - public void merge(Accumulator other) { - aa = combiner.mergeAccumulators(Lists.newArrayList(aa, ((CombineFnAggregatorWrapper)other).aa)); - } - - @Override - public Accumulator clone() { - // copy it by merging - AA aaCopy = combiner.mergeAccumulators(Lists.newArrayList(aa)); - CombineFnAggregatorWrapper result = new - CombineFnAggregatorWrapper<>(combiner); - result.aa = aaCopy; - return result; - } - - @Override - public void addValue(AI value) { - add(value); - } - - @Override - public String getName() { - return "CombineFn: " + combiner.toString(); - } - - @Override - public Combine.CombineFn getCombineFn() { - return combiner; - } - -} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java index eb32fa2fd74a..82d3fb8ffae3 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java @@ -33,20 +33,21 @@ * the function as an aggregator in a {@link org.apache.beam.sdk.transforms.ParDo} * operation. 
*/ -public class SerializableFnAggregatorWrapper implements Aggregator, Accumulator { +public class SerializableFnAggregatorWrapper + implements Aggregator, Accumulator { - private AO aa; - private Combine.CombineFn combiner; + private OutputT aa; + private Combine.CombineFn combiner; - public SerializableFnAggregatorWrapper(Combine.CombineFn combiner) { + public SerializableFnAggregatorWrapper(Combine.CombineFn combiner) { this.combiner = combiner; resetLocal(); } - + @Override @SuppressWarnings("unchecked") - public void add(AI value) { - this.aa = combiner.apply(ImmutableList.of((AI) aa, value)); + public void add(InputT value) { + this.aa = combiner.apply(ImmutableList.of((InputT) aa, value)); } @Override @@ -56,17 +57,17 @@ public Serializable getLocalValue() { @Override public void resetLocal() { - this.aa = combiner.apply(ImmutableList.of()); + this.aa = combiner.apply(ImmutableList.of()); } @Override @SuppressWarnings("unchecked") - public void merge(Accumulator other) { - this.aa = combiner.apply(ImmutableList.of((AI) aa, (AI) other.getLocalValue())); + public void merge(Accumulator other) { + this.aa = combiner.apply(ImmutableList.of((InputT) aa, (InputT) other.getLocalValue())); } @Override - public void addValue(AI value) { + public void addValue(InputT value) { add(value); } @@ -76,15 +77,15 @@ public String getName() { } @Override - public Combine.CombineFn getCombineFn() { + public Combine.CombineFn getCombineFn() { return combiner; } @Override - public Accumulator clone() { + public Accumulator clone() { // copy it by merging - AO resultCopy = combiner.apply(Lists.newArrayList((AI) aa)); - SerializableFnAggregatorWrapper result = new + OutputT resultCopy = combiner.apply(Lists.newArrayList((InputT) aa)); + SerializableFnAggregatorWrapper result = new SerializableFnAggregatorWrapper<>(combiner); result.aa = resultCopy; diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java index 53e544d9e8fc..c0a71329fe3f 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.io.Sink; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.flink.api.common.io.OutputFormat; import org.apache.flink.configuration.Configuration; @@ -31,10 +32,11 @@ import java.lang.reflect.Field; /** - * Wrapper class to use generic Write.Bound transforms as sinks. + * Wrapper for executing a {@link Sink} on Flink as an {@link OutputFormat}. + * * @param The type of the incoming records. 
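The SerializableFnAggregatorWrapper change above is essentially a generics rename (AI/AO to InputT/OutputT); the wrapper itself keeps a single running value and folds every add and merge through the combine function. A rough standalone sketch of that accumulator pattern in plain Java; FnAccumulator and its method names are hypothetical, not the Flink Accumulator interface:

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

public class CombineFnAccumulatorSketch {

  // Hypothetical aggregator that wraps a "combine a list into one value" function.
  static final class FnAccumulator<T> {
    private final Function<List<T>, T> combine;
    private T current;

    FnAccumulator(Function<List<T>, T> combine) {
      this.combine = combine;
      this.current = combine.apply(Arrays.<T>asList()); // identity value, like resetLocal()
    }

    void add(T value) {
      current = combine.apply(Arrays.asList(current, value)); // fold the new value in
    }

    void merge(FnAccumulator<T> other) {
      current = combine.apply(Arrays.asList(current, other.current)); // merge two partial results
    }

    T getLocalValue() {
      return current;
    }
  }

  public static void main(String[] args) {
    Function<List<Integer>, Integer> sum =
        values -> values.stream().mapToInt(Integer::intValue).sum();

    FnAccumulator<Integer> a = new FnAccumulator<>(sum);
    a.add(3);
    a.add(4);

    FnAccumulator<Integer> b = new FnAccumulator<>(sum);
    b.add(10);

    a.merge(b);
    System.out.println(a.getLocalValue()); // 17
  }
}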
*/ -public class SinkOutputFormat implements OutputFormat { +public class SinkOutputFormat implements OutputFormat> { private final Sink sink; @@ -75,9 +77,9 @@ public void open(int taskNumber, int numTasks) throws IOException { } @Override - public void writeRecord(T record) throws IOException { + public void writeRecord(WindowedValue record) throws IOException { try { - writer.write(record); + writer.write(record.getValue()); } catch (Exception e) { throw new IOException("Couldn't write record.", e); } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java index debd1a14d525..1d06b1ac2fc9 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java @@ -21,12 +21,16 @@ import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.flink.api.common.io.DefaultInputSplitAssigner; import org.apache.flink.api.common.io.InputFormat; import org.apache.flink.api.common.io.statistics.BaseStatistics; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.io.InputSplitAssigner; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,10 +39,10 @@ /** - * A Flink {@link org.apache.flink.api.common.io.InputFormat} that wraps a - * Dataflow {@link org.apache.beam.sdk.io.Source}. + * Wrapper for executing a {@link Source} as a Flink {@link InputFormat}. 
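With the SinkOutputFormat change above and the SourceInputFormat change that follows, elements cross the wrapper boundary as windowed values: the source wraps each element, together with the reader's timestamp, into the global window with no firing pane, and the sink unwraps with getValue() before writing. A small standalone sketch of the wrapping step in plain Java; Windowed and wrapIntoGlobalWindow are hypothetical stand-ins, not the Beam WindowedValue API:

import java.util.Arrays;
import java.util.Iterator;

public class WindowedReaderSketch {

  static final String GLOBAL_WINDOW = "GlobalWindow"; // stand-in for GlobalWindow.INSTANCE

  // Hypothetical stand-in for a WindowedValue: element + timestamp + window.
  static final class Windowed<T> {
    final T value;
    final long timestampMillis;
    final String window;

    Windowed(T value, long timestampMillis, String window) {
      this.value = value;
      this.timestampMillis = timestampMillis;
      this.window = window;
    }

    @Override
    public String toString() {
      return value + "@" + timestampMillis + " in " + window;
    }
  }

  // Mirrors the nextRecord() change: each element is emitted with its timestamp in the global window.
  static <T> Windowed<T> wrapIntoGlobalWindow(T element, long timestampMillis) {
    return new Windowed<>(element, timestampMillis, GLOBAL_WINDOW);
  }

  public static void main(String[] args) {
    // Pretend these came from a bounded reader that also reports per-element timestamps.
    Iterator<String> elements = Arrays.asList("a", "b").iterator();
    long timestamp = 0;
    while (elements.hasNext()) {
      System.out.println(wrapIntoGlobalWindow(elements.next(), ++timestamp));
      // prints: a@1 in GlobalWindow, then b@2 in GlobalWindow
    }
  }
}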
*/ -public class SourceInputFormat implements InputFormat> { +public class SourceInputFormat + implements InputFormat, SourceInputSplit> { private static final Logger LOG = LoggerFactory.getLogger(SourceInputFormat.class); private final BoundedSource initialSource; @@ -122,12 +126,16 @@ public boolean reachedEnd() throws IOException { } @Override - public T nextRecord(T t) throws IOException { + public WindowedValue nextRecord(WindowedValue t) throws IOException { if (inputAvailable) { final T current = reader.getCurrent(); + final Instant timestamp = reader.getCurrentTimestamp(); // advance reader to have a record ready next time inputAvailable = reader.advance(); - return current; + return WindowedValue.of( + current, + timestamp, + GlobalWindow.INSTANCE, PaneInfo.NO_FIRING); } return null; diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java index 3bf566bce762..6b69d547cf12 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupByKeyWrapper.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.flink.translation.wrappers.streaming; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; -import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; @@ -54,7 +53,7 @@ public static KeyedStream>, K> groupStreamByKey(Da @Override public K getKey(WindowedValue> value) throws Exception { - return isKeyVoid ? (K) VoidCoderTypeSerializer.VoidValue.INSTANCE : + return isKeyVoid ? 
(K) VoidValue.INSTANCE : value.getValue().getKey(); } @@ -64,4 +63,11 @@ public TypeInformation getProducedType() { } }); } + + // special type to return as key for null key + public static class VoidValue { + private VoidValue() {} + + public static VoidValue INSTANCE = new VoidValue(); + } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java index d6aff7d7a4ee..8cd8351021b4 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/FlinkStreamingCreateFunction.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.io; -import org.apache.beam.runners.flink.translation.types.VoidCoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; @@ -47,17 +46,11 @@ public FlinkStreamingCreateFunction(List elements, Coder coder) { @Override public void flatMap(IN value, Collector> out) throws Exception { - @SuppressWarnings("unchecked") - OUT voidValue = (OUT) VoidCoderTypeSerializer.VoidValue.INSTANCE; for (byte[] element : elements) { ByteArrayInputStream bai = new ByteArrayInputStream(element); OUT outValue = coder.decode(bai, Coder.Context.OUTER); - if (outValue == null) { - out.collect(WindowedValue.of(voidValue, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); - } else { - out.collect(WindowedValue.of(outValue, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); - } + out.collect(WindowedValue.of(outValue, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING)); } out.close(); diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java deleted file mode 100644 index 113fee0881de..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.AvroCoder; -import org.apache.beam.sdk.io.AvroIO; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - - -public class AvroITCase extends JavaProgramTestBase { - - protected String resultPath; - protected String tmpPath; - - public AvroITCase(){ - } - - static final String[] EXPECTED_RESULT = new String[] { - "Joe red 3", - "Mary blue 4", - "Mark green 1", - "Julia purple 5" - }; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - tmpPath = getTempDirPath("tmp"); - - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); - } - - @Override - protected void testProgram() throws Exception { - runProgram(tmpPath, resultPath); - } - - private static void runProgram(String tmpPath, String resultPath) { - Pipeline p = FlinkTestPipeline.createForBatch(); - - p - .apply(Create.of( - new User("Joe", 3, "red"), - new User("Mary", 4, "blue"), - new User("Mark", 1, "green"), - new User("Julia", 5, "purple")) - .withCoder(AvroCoder.of(User.class))) - - .apply(AvroIO.Write.to(tmpPath) - .withSchema(User.class)); - - p.run(); - - p = FlinkTestPipeline.createForBatch(); - - p - .apply(AvroIO.Read.from(tmpPath).withSchema(User.class).withoutValidation()) - - .apply(ParDo.of(new DoFn() { - @Override - public void processElement(ProcessContext c) throws Exception { - User u = c.element(); - String result = u.getName() + " " + u.getFavoriteColor() + " " + u.getFavoriteNumber(); - c.output(result); - } - })) - - .apply(TextIO.Write.to(resultPath)); - - p.run(); - } - - private static class User { - - private String name; - private int favoriteNumber; - private String favoriteColor; - - public User() {} - - public User(String name, int favoriteNumber, String favoriteColor) { - this.name = name; - this.favoriteNumber = favoriteNumber; - this.favoriteColor = favoriteColor; - } - - public String getName() { - return name; - } - - public String getFavoriteColor() { - return favoriteColor; - } - - public int getFavoriteNumber() { - return favoriteNumber; - } - } - -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java deleted file mode 100644 index ac0a3d7d4d67..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.Flatten; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionList; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -public class FlattenizeITCase extends JavaProgramTestBase { - - private String resultPath; - private String resultPath2; - - private static final String[] words = {"hello", "this", "is", "a", "DataSet!"}; - private static final String[] words2 = {"hello", "this", "is", "another", "DataSet!"}; - private static final String[] words3 = {"hello", "this", "is", "yet", "another", "DataSet!"}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - resultPath2 = getTempDirPath("result2"); - } - - @Override - protected void postSubmit() throws Exception { - String join = Joiner.on('\n').join(words); - String join2 = Joiner.on('\n').join(words2); - String join3 = Joiner.on('\n').join(words3); - compareResultsByLinesInMemory(join + "\n" + join2, resultPath); - compareResultsByLinesInMemory(join + "\n" + join2 + "\n" + join3, resultPath2); - } - - - @Override - protected void testProgram() throws Exception { - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection p1 = p.apply(Create.of(words)); - PCollection p2 = p.apply(Create.of(words2)); - - PCollectionList list = PCollectionList.of(p1).and(p2); - - list.apply(Flatten.pCollections()).apply(TextIO.Write.to(resultPath)); - - PCollection p3 = p.apply(Create.of(words3)); - - PCollectionList list2 = list.and(p3); - - list2.apply(Flatten.pCollections()).apply(TextIO.Write.to(resultPath2)); - - p.run(); - } - -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java deleted file mode 100644 index 47685b6be6f3..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.runners.flink.util.JoinExamples; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.PCollection; - -import com.google.api.services.bigquery.model.TableRow; -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.util.Arrays; -import java.util.List; - - -/** - * Unfortunately we need to copy the code from the Dataflow SDK because it is not public there. - */ -public class JoinExamplesITCase extends JavaProgramTestBase { - - protected String resultPath; - - public JoinExamplesITCase(){ - } - - private static final TableRow row1 = new TableRow() - .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") - .set("Actor1Name", "BANGKOK").set("SOURCEURL", "http://cnn.com"); - private static final TableRow row2 = new TableRow() - .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") - .set("Actor1Name", "LAOS").set("SOURCEURL", "http://www.chicagotribune.com"); - private static final TableRow row3 = new TableRow() - .set("ActionGeo_CountryCode", "BE").set("SQLDATE", "20141213") - .set("Actor1Name", "AFGHANISTAN").set("SOURCEURL", "http://cnn.com"); - static final TableRow[] EVENTS = new TableRow[] { - row1, row2, row3 - }; - static final List EVENT_ARRAY = Arrays.asList(EVENTS); - - private static final TableRow cc1 = new TableRow() - .set("FIPSCC", "VM").set("HumanName", "Vietnam"); - private static final TableRow cc2 = new TableRow() - .set("FIPSCC", "BE").set("HumanName", "Belgium"); - static final TableRow[] CCS = new TableRow[] { - cc1, cc2 - }; - static final List CC_ARRAY = Arrays.asList(CCS); - - static final String[] JOINED_EVENTS = new String[] { - "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: LAOS, " - + "url: http://www.chicagotribune.com", - "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: BANGKOK, " - + "url: http://cnn.com", - "Country code: BE, Country name: Belgium, Event info: Date: 20141213, Actor1: AFGHANISTAN, " - + "url: http://cnn.com" - }; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(JOINED_EVENTS), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection input1 = p.apply(Create.of(EVENT_ARRAY)); - PCollection input2 = p.apply(Create.of(CC_ARRAY)); - - PCollection output = JoinExamples.joinEvents(input1, input2); - - output.apply(TextIO.Write.to(resultPath)); - - p.run(); - } -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java deleted file mode 100644 index 4d66fa421c5e..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.io.Serializable; - -public class MaybeEmptyTestITCase extends JavaProgramTestBase implements Serializable { - - protected String resultPath; - - protected final String expected = "test"; - - public MaybeEmptyTestITCase() { - } - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(expected, resultPath); - } - - @Override - protected void testProgram() throws Exception { - - Pipeline p = FlinkTestPipeline.createForBatch(); - - p.apply(Create.of((Void) null)).setCoder(VoidCoder.of()) - .apply(ParDo.of( - new DoFn() { - @Override - public void processElement(DoFn.ProcessContext c) { - c.output(expected); - } - })).apply(TextIO.Write.to(resultPath)); - p.run(); - } - -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java deleted file mode 100644 index a2ef4e29f403..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.io.Serializable; - -public class ParDoMultiOutputITCase extends JavaProgramTestBase implements Serializable { - - private String resultPath; - - private static String[] expectedWords = {"MAAA", "MAAFOOO"}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on("\n").join(expectedWords), resultPath); - } - - @Override - protected void testProgram() throws Exception { - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection words = p.apply(Create.of("Hello", "Whatupmyman", "hey", "SPECIALthere", "MAAA", "MAAFOOO")); - - // Select words whose length is below a cut off, - // plus the lengths of words that are above the cut off. - // Also select words starting with "MARKER". - final int wordLengthCutOff = 3; - // Create tags to use for the main and side outputs. - final TupleTag wordsBelowCutOffTag = new TupleTag(){}; - final TupleTag wordLengthsAboveCutOffTag = new TupleTag(){}; - final TupleTag markedWordsTag = new TupleTag(){}; - - PCollectionTuple results = - words.apply(ParDo - .withOutputTags(wordsBelowCutOffTag, TupleTagList.of(wordLengthsAboveCutOffTag) - .and(markedWordsTag)) - .of(new DoFn() { - final TupleTag specialWordsTag = new TupleTag() { - }; - - public void processElement(ProcessContext c) { - String word = c.element(); - if (word.length() <= wordLengthCutOff) { - c.output(word); - } else { - c.sideOutput(wordLengthsAboveCutOffTag, word.length()); - } - if (word.startsWith("MAA")) { - c.sideOutput(markedWordsTag, word); - } - - if (word.startsWith("SPECIAL")) { - c.sideOutput(specialWordsTag, word); - } - } - })); - - // Extract the PCollection results, by tag. - PCollection wordsBelowCutOff = results.get(wordsBelowCutOffTag); - PCollection wordLengthsAboveCutOff = results.get - (wordLengthsAboveCutOffTag); - PCollection markedWords = results.get(markedWordsTag); - - markedWords.apply(TextIO.Write.to(resultPath)); - - p.run(); - } -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java index 66c959eea90c..bb79b270945c 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java @@ -28,6 +28,9 @@ import org.apache.flink.test.util.JavaProgramTestBase; +import java.io.File; +import java.net.URI; + /** * Reads from a bounded source in batch execution. 
*/ @@ -44,6 +47,13 @@ public ReadSourceITCase(){ @Override protected void preSubmit() throws Exception { resultPath = getTempDirPath("result"); + + // need to create the dir, otherwise Beam sinks don't + // work for these tests + + if (!new File(new URI(resultPath)).mkdirs()) { + throw new RuntimeException("Could not create output dir."); + } } @Override @@ -56,7 +66,7 @@ protected void testProgram() throws Exception { runProgram(resultPath); } - private static void runProgram(String resultPath) { + private static void runProgram(String resultPath) throws Exception { Pipeline p = FlinkTestPipeline.createForBatch(); @@ -69,7 +79,7 @@ public void processElement(ProcessContext c) throws Exception { } })); - result.apply(TextIO.Write.to(resultPath)); + result.apply(TextIO.Write.to(new URI(resultPath).getPath() + "/part")); p.run(); } diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java deleted file mode 100644 index 471d3262a36c..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.RemoveDuplicates; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.util.Collections; -import java.util.List; - - -public class RemoveDuplicatesEmptyITCase extends JavaProgramTestBase { - - protected String resultPath; - - public RemoveDuplicatesEmptyITCase(){ - } - - static final String[] EXPECTED_RESULT = new String[] {}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - List strings = Collections.emptyList(); - - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection input = - p.apply(Create.of(strings)) - .setCoder(StringUtf8Coder.of()); - - PCollection output = - input.apply(RemoveDuplicates.create()); - - output.apply(TextIO.Write.to(resultPath)); - p.run(); - } -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java deleted file mode 100644 index 0544f20eb310..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.RemoveDuplicates; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.util.Arrays; -import java.util.List; - - -public class RemoveDuplicatesITCase extends JavaProgramTestBase { - - protected String resultPath; - - public RemoveDuplicatesITCase(){ - } - - static final String[] EXPECTED_RESULT = new String[] { - "k1", "k5", "k2", "k3"}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - List strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); - - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection input = - p.apply(Create.of(strings)) - .setCoder(StringUtf8Coder.of()); - - PCollection output = - input.apply(RemoveDuplicates.create()); - - output.apply(TextIO.Write.to(resultPath)); - p.run(); - } -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java deleted file mode 100644 index 2c7c65e8af3d..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.values.PCollectionView; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.io.Serializable; - -public class SideInputITCase extends JavaProgramTestBase implements Serializable { - - private static final String expected = "Hello!"; - - protected String resultPath; - - @Override - protected void testProgram() throws Exception { - - - Pipeline p = FlinkTestPipeline.createForBatch(); - - - final PCollectionView sidesInput = p - .apply(Create.of(expected)) - .apply(View.asSingleton()); - - p.apply(Create.of("bli")) - .apply(ParDo.of(new DoFn() { - @Override - public void processElement(ProcessContext c) throws Exception { - String s = c.sideInput(sidesInput); - c.output(s); - } - }).withSideInputs(sidesInput)).apply(TextIO.Write.to(resultPath)); - - p.run(); - } - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(expected, resultPath); - } -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java deleted file mode 100644 index 547f3c3a4660..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.examples.complete.TfIdf; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.StringDelegateCoder; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.Keys; -import org.apache.beam.sdk.transforms.RemoveDuplicates; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.net.URI; - - -public class TfIdfITCase extends JavaProgramTestBase { - - protected String resultPath; - - public TfIdfITCase(){ - } - - static final String[] EXPECTED_RESULT = new String[] { - "a", "m", "n", "b", "c", "d"}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - Pipeline pipeline = FlinkTestPipeline.createForBatch(); - - pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); - - PCollection>> wordToUriAndTfIdf = pipeline - .apply(Create.of( - KV.of(new URI("x"), "a b c d"), - KV.of(new URI("y"), "a b c"), - KV.of(new URI("z"), "a m n"))) - .apply(new TfIdf.ComputeTfIdf()); - - PCollection words = wordToUriAndTfIdf - .apply(Keys.create()) - .apply(RemoveDuplicates.create()); - - words.apply(TextIO.Write.to(resultPath)); - - pipeline.run(); - } -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java deleted file mode 100644 index 3254e7885db8..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.examples.WordCount; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - -import java.util.Arrays; -import java.util.List; - - -public class WordCountITCase extends JavaProgramTestBase { - - protected String resultPath; - - public WordCountITCase(){ - } - - static final String[] WORDS_ARRAY = new String[] { - "hi there", "hi", "hi sue bob", - "hi sue", "", "bob hi"}; - - static final List WORDS = Arrays.asList(WORDS_ARRAY); - - static final String[] COUNTS_ARRAY = new String[] { - "hi: 5", "there: 1", "sue: 2", "bob: 2"}; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(COUNTS_ARRAY), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - Pipeline p = FlinkTestPipeline.createForBatch(); - - PCollection input = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); - - input - .apply(new WordCount.CountWords()) - .apply(MapElements.via(new WordCount.FormatAsTextFn())) - .apply(TextIO.Write.to(resultPath)); - - p.run(); - } -} - diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java deleted file mode 100644 index 6570e7df5508..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin2ITCase.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.transforms.join.CoGroupByKey; -import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TupleTag; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - - -public class WordCountJoin2ITCase extends JavaProgramTestBase { - - static final String[] WORDS_1 = new String[] { - "hi there", "hi", "hi sue bob", - "hi sue", "", "bob hi"}; - - static final String[] WORDS_2 = new String[] { - "hi tim", "beauty", "hooray sue bob", - "hi there", "", "please say hi"}; - - static final String[] RESULTS = new String[] { - "beauty -> Tag1: Tag2: 1", - "bob -> Tag1: 2 Tag2: 1", - "hi -> Tag1: 5 Tag2: 3", - "hooray -> Tag1: Tag2: 1", - "please -> Tag1: Tag2: 1", - "say -> Tag1: Tag2: 1", - "sue -> Tag1: 2 Tag2: 1", - "there -> Tag1: 1 Tag2: 1", - "tim -> Tag1: Tag2: 1" - }; - - static final TupleTag tag1 = new TupleTag<>("Tag1"); - static final TupleTag tag2 = new TupleTag<>("Tag2"); - - protected String resultPath; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(RESULTS), resultPath); - } - - @Override - protected void testProgram() throws Exception { - Pipeline p = FlinkTestPipeline.createForBatch(); - - /* Create two PCollections and join them */ - PCollection> occurences1 = p.apply(Create.of(WORDS_1)) - .apply(ParDo.of(new ExtractWordsFn())) - .apply(Count.perElement()); - - PCollection> occurences2 = p.apply(Create.of(WORDS_2)) - .apply(ParDo.of(new ExtractWordsFn())) - .apply(Count.perElement()); - - /* CoGroup the two collections */ - PCollection> mergedOccurences = KeyedPCollectionTuple - .of(tag1, occurences1) - .and(tag2, occurences2) - .apply(CoGroupByKey.create()); - - /* Format output */ - mergedOccurences.apply(ParDo.of(new FormatCountsFn())) - .apply(TextIO.Write.named("test").to(resultPath)); - - p.run(); - } - - - static class ExtractWordsFn extends DoFn { - - @Override - public void startBundle(Context c) { - } - - @Override - public void processElement(ProcessContext c) { - // Split the line into words. - String[] words = c.element().split("[^a-zA-Z']+"); - - // Output each word encountered into the output PCollection. 
- for (String word : words) { - if (!word.isEmpty()) { - c.output(word); - } - } - } - } - - static class FormatCountsFn extends DoFn, String> { - @Override - public void processElement(ProcessContext c) { - CoGbkResult value = c.element().getValue(); - String key = c.element().getKey(); - String countTag1 = tag1.getId() + ": "; - String countTag2 = tag2.getId() + ": "; - for (Long count : value.getAll(tag1)) { - countTag1 += count + " "; - } - for (Long count : value.getAll(tag2)) { - countTag2 += count; - } - c.output(key + " -> " + countTag1 + countTag2); - } - } - - -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java deleted file mode 100644 index 60dc74af90b6..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountJoin3ITCase.java +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.transforms.join.CoGroupByKey; -import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TupleTag; - -import com.google.common.base.Joiner; - -import org.apache.flink.test.util.JavaProgramTestBase; - - -public class WordCountJoin3ITCase extends JavaProgramTestBase { - - static final String[] WORDS_1 = new String[] { - "hi there", "hi", "hi sue bob", - "hi sue", "", "bob hi"}; - - static final String[] WORDS_2 = new String[] { - "hi tim", "beauty", "hooray sue bob", - "hi there", "", "please say hi"}; - - static final String[] WORDS_3 = new String[] { - "hi stephan", "beauty", "hooray big fabian", - "hi yo", "", "please say hi"}; - - static final String[] RESULTS = new String[] { - "beauty -> Tag1: Tag2: 1 Tag3: 1", - "bob -> Tag1: 2 Tag2: 1 Tag3: ", - "hi -> Tag1: 5 Tag2: 3 Tag3: 3", - "hooray -> Tag1: Tag2: 1 Tag3: 1", - "please -> Tag1: Tag2: 1 Tag3: 1", - "say -> Tag1: Tag2: 1 Tag3: 1", - "sue -> Tag1: 2 Tag2: 1 Tag3: ", - "there -> Tag1: 1 Tag2: 1 Tag3: ", - "tim -> Tag1: Tag2: 1 Tag3: ", - "stephan -> Tag1: Tag2: Tag3: 1", - "yo -> Tag1: Tag2: Tag3: 1", - "fabian -> Tag1: Tag2: Tag3: 1", - "big -> Tag1: Tag2: Tag3: 1" - }; - - static final TupleTag tag1 = new TupleTag<>("Tag1"); - static final TupleTag tag2 = new 
TupleTag<>("Tag2"); - static final TupleTag tag3 = new TupleTag<>("Tag3"); - - protected String resultPath; - - @Override - protected void preSubmit() throws Exception { - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(Joiner.on('\n').join(RESULTS), resultPath); - } - - @Override - protected void testProgram() throws Exception { - - Pipeline p = FlinkTestPipeline.createForBatch(); - - /* Create two PCollections and join them */ - PCollection> occurences1 = p.apply(Create.of(WORDS_1)) - .apply(ParDo.of(new ExtractWordsFn())) - .apply(Count.perElement()); - - PCollection> occurences2 = p.apply(Create.of(WORDS_2)) - .apply(ParDo.of(new ExtractWordsFn())) - .apply(Count.perElement()); - - PCollection> occurences3 = p.apply(Create.of(WORDS_3)) - .apply(ParDo.of(new ExtractWordsFn())) - .apply(Count.perElement()); - - /* CoGroup the two collections */ - PCollection> mergedOccurences = KeyedPCollectionTuple - .of(tag1, occurences1) - .and(tag2, occurences2) - .and(tag3, occurences3) - .apply(CoGroupByKey.create()); - - /* Format output */ - mergedOccurences.apply(ParDo.of(new FormatCountsFn())) - .apply(TextIO.Write.named("test").to(resultPath)); - - p.run(); - } - - - static class ExtractWordsFn extends DoFn { - - @Override - public void startBundle(Context c) { - } - - @Override - public void processElement(ProcessContext c) { - // Split the line into words. - String[] words = c.element().split("[^a-zA-Z']+"); - - // Output each word encountered into the output PCollection. - for (String word : words) { - if (!word.isEmpty()) { - c.output(word); - } - } - } - } - - static class FormatCountsFn extends DoFn, String> { - @Override - public void processElement(ProcessContext c) { - CoGbkResult value = c.element().getValue(); - String key = c.element().getKey(); - String countTag1 = tag1.getId() + ": "; - String countTag2 = tag2.getId() + ": "; - String countTag3 = tag3.getId() + ": "; - for (Long count : value.getAll(tag1)) { - countTag1 += count + " "; - } - for (Long count : value.getAll(tag2)) { - countTag2 += count + " "; - } - for (Long count : value.getAll(tag3)) { - countTag3 += count; - } - c.output(key + " -> " + countTag1 + countTag2 + countTag3); - } - } - -} diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java index c76af657b9ab..3e5a17dbdfea 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/GroupAlsoByWindowTest.java @@ -44,6 +44,7 @@ import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase; import org.apache.flink.streaming.util.TestHarnessUtil; import org.joda.time.Duration; import org.joda.time.Instant; @@ -53,7 +54,7 @@ import java.util.Comparator; import java.util.concurrent.ConcurrentLinkedQueue; -public class GroupAlsoByWindowTest { +public class GroupAlsoByWindowTest extends StreamingMultipleProgramsTestBase { private final Combine.CombineFn combiner = new Sum.SumIntegerFn(); diff --git 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java deleted file mode 100644 index e6b7f64f69a1..000000000000 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/util/JoinExamples.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.util; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.io.BigQueryIO; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.options.Validation; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.join.CoGbkResult; -import org.apache.beam.sdk.transforms.join.CoGroupByKey; -import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TupleTag; - -import com.google.api.services.bigquery.model.TableRow; - -/** - * Copied from {@link org.apache.beam.examples.JoinExamples} because the code - * is private there. - */ -public class JoinExamples { - - // A 1000-row sample of the GDELT data here: gdelt-bq:full.events. - private static final String GDELT_EVENTS_TABLE = - "clouddataflow-readonly:samples.gdelt_sample"; - // A table that maps country codes to country names. - private static final String COUNTRY_CODES = - "gdelt-bq:full.crosswalk_geocountrycodetohuman"; - - /** - * Join two collections, using country code as the key. - */ - public static PCollection joinEvents(PCollection eventsTable, - PCollection countryCodes) throws Exception { - - final TupleTag eventInfoTag = new TupleTag<>(); - final TupleTag countryInfoTag = new TupleTag<>(); - - // transform both input collections to tuple collections, where the keys are country - // codes in both cases. - PCollection> eventInfo = eventsTable.apply( - ParDo.of(new ExtractEventDataFn())); - PCollection> countryInfo = countryCodes.apply( - ParDo.of(new ExtractCountryInfoFn())); - - // country code 'key' -> CGBKR (, ) - PCollection> kvpCollection = KeyedPCollectionTuple - .of(eventInfoTag, eventInfo) - .and(countryInfoTag, countryInfo) - .apply(CoGroupByKey.create()); - - // Process the CoGbkResult elements generated by the CoGroupByKey transform. 
- // country code 'key' -> string of , - PCollection> finalResultCollection = - kvpCollection.apply(ParDo.of(new DoFn, KV>() { - @Override - public void processElement(ProcessContext c) { - KV e = c.element(); - CoGbkResult val = e.getValue(); - String countryCode = e.getKey(); - String countryName; - countryName = e.getValue().getOnly(countryInfoTag, "Kostas"); - for (String eventInfo : c.element().getValue().getAll(eventInfoTag)) { - // Generate a string that combines information from both collection values - c.output(KV.of(countryCode, "Country name: " + countryName - + ", Event info: " + eventInfo)); - } - } - })); - - // write to GCS - return finalResultCollection - .apply(ParDo.of(new DoFn, String>() { - @Override - public void processElement(ProcessContext c) { - String outputstring = "Country code: " + c.element().getKey() - + ", " + c.element().getValue(); - c.output(outputstring); - } - })); - } - - /** - * Examines each row (event) in the input table. Output a KV with the key the country - * code of the event, and the value a string encoding event information. - */ - static class ExtractEventDataFn extends DoFn> { - @Override - public void processElement(ProcessContext c) { - TableRow row = c.element(); - String countryCode = (String) row.get("ActionGeo_CountryCode"); - String sqlDate = (String) row.get("SQLDATE"); - String actor1Name = (String) row.get("Actor1Name"); - String sourceUrl = (String) row.get("SOURCEURL"); - String eventInfo = "Date: " + sqlDate + ", Actor1: " + actor1Name + ", url: " + sourceUrl; - c.output(KV.of(countryCode, eventInfo)); - } - } - - - /** - * Examines each row (country info) in the input table. Output a KV with the key the country - * code, and the value the country name. - */ - static class ExtractCountryInfoFn extends DoFn> { - @Override - public void processElement(ProcessContext c) { - TableRow row = c.element(); - String countryCode = (String) row.get("FIPSCC"); - String countryName = (String) row.get("HumanName"); - c.output(KV.of(countryCode, countryName)); - } - } - - - /** - * Options supported by {@link JoinExamples}. - *
    - * Inherits standard configuration options. - */ - private interface Options extends PipelineOptions { - @Description("Path of the file to write to") - @Validation.Required - String getOutput(); - void setOutput(String value); - } - - public static void main(String[] args) throws Exception { - Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); - Pipeline p = Pipeline.create(options); - // the following two 'applys' create multiple inputs to our pipeline, one for each - // of our two input sources. - PCollection eventsTable = p.apply(BigQueryIO.Read.from(GDELT_EVENTS_TABLE)); - PCollection countryCodes = p.apply(BigQueryIO.Read.from(COUNTRY_CODES)); - PCollection formattedResults = joinEvents(eventsTable, countryCodes); - formattedResults.apply(TextIO.Write.to(options.getOutput())); - p.run(); - } - -} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/UnionCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/UnionCoder.java index 2ca7014691af..29240e7bb863 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/UnionCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/UnionCoder.java @@ -35,7 +35,7 @@ /** * A UnionCoder encodes RawUnionValues. */ -class UnionCoder extends StandardCoder { +public class UnionCoder extends StandardCoder { // TODO: Think about how to integrate this with a schema object (i.e. // a tuple of tuple tags). /** From 23ba976403b308103ea8b7dd0505e5847dd44952 Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Thu, 19 May 2016 14:05:05 +0200 Subject: [PATCH 17/21] Add surefire plugin to java 8 example tests Without this the test will just pick up whathever was last written to beamTestPipelineOptions. --- examples/java8/pom.xml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/java8/pom.xml b/examples/java8/pom.xml index b4a9ec6d7fa4..e211739a9412 100644 --- a/examples/java8/pom.xml +++ b/examples/java8/pom.xml @@ -49,6 +49,18 @@ + + org.apache.maven.plugins + maven-surefire-plugin + + + + + + + + + org.apache.maven.plugins maven-dependency-plugin From 145049f45b97951ae6cd1ef5e9ca64ea1ce22b4d Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 20 May 2016 13:43:10 -0700 Subject: [PATCH 18/21] Remove CustomSourcesTest The paths tested by this test are exercised in existing tests marked RunnableOnService, which have real sources. --- .../dataflow/internal/CustomSourcesTest.java | 276 ------------------ 1 file changed, 276 deletions(-) delete mode 100644 runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/internal/CustomSourcesTest.java diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/internal/CustomSourcesTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/internal/CustomSourcesTest.java deleted file mode 100644 index ed86be22d8a9..000000000000 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/internal/CustomSourcesTest.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.internal; -import static org.apache.beam.sdk.testing.SourceTestUtils.readFromSource; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - -import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.testing.ExpectedLogs; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Sample; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.FixedWindows; -import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.values.PCollection; - -import com.google.common.base.Preconditions; - -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -/** - * Tests for {@link CustomSources}. 
- */ -@RunWith(JUnit4.class) -public class CustomSourcesTest { - @Rule public ExpectedException expectedException = ExpectedException.none(); - @Rule public ExpectedLogs logged = ExpectedLogs.none(CustomSources.class); - - static class TestIO { - public static Read fromRange(int from, int to) { - return new Read(from, to, false); - } - - static class Read extends BoundedSource { - final int from; - final int to; - final boolean produceTimestamps; - - Read(int from, int to, boolean produceTimestamps) { - this.from = from; - this.to = to; - this.produceTimestamps = produceTimestamps; - } - - public Read withTimestampsMillis() { - return new Read(from, to, true); - } - - @Override - public List splitIntoBundles(long desiredBundleSizeBytes, PipelineOptions options) - throws Exception { - List res = new ArrayList<>(); - DataflowPipelineOptions dataflowOptions = options.as(DataflowPipelineOptions.class); - float step = 1.0f * (to - from) / dataflowOptions.getNumWorkers(); - for (int i = 0; i < dataflowOptions.getNumWorkers(); ++i) { - res.add(new Read( - Math.round(from + i * step), Math.round(from + (i + 1) * step), - produceTimestamps)); - } - return res; - } - - @Override - public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { - return 8 * (to - from); - } - - @Override - public boolean producesSortedKeys(PipelineOptions options) throws Exception { - return true; - } - - @Override - public BoundedReader createReader(PipelineOptions options) throws IOException { - return new RangeReader(this); - } - - @Override - public void validate() {} - - @Override - public String toString() { - return "[" + from + ", " + to + ")"; - } - - @Override - public Coder getDefaultOutputCoder() { - return BigEndianIntegerCoder.of(); - } - - private static class RangeReader extends BoundedReader { - // To verify that BasicSerializableSourceFormat calls our methods according to protocol. - enum State { - UNSTARTED, - STARTED, - FINISHED - } - private Read source; - private int current = -1; - private State state = State.UNSTARTED; - - public RangeReader(Read source) { - this.source = source; - } - - @Override - public boolean start() throws IOException { - Preconditions.checkState(state == State.UNSTARTED); - state = State.STARTED; - current = source.from; - return (current < source.to); - } - - @Override - public boolean advance() throws IOException { - Preconditions.checkState(state == State.STARTED); - if (current == source.to - 1) { - state = State.FINISHED; - return false; - } - current++; - return true; - } - - @Override - public Integer getCurrent() { - Preconditions.checkState(state == State.STARTED); - return current; - } - - @Override - public Instant getCurrentTimestamp() { - return source.produceTimestamps - ? 
new Instant(current /* as millis */) : BoundedWindow.TIMESTAMP_MIN_VALUE; - } - - @Override - public void close() throws IOException { - Preconditions.checkState(state == State.STARTED || state == State.FINISHED); - state = State.FINISHED; - } - - @Override - public Read getCurrentSource() { - return source; - } - - @Override - public Read splitAtFraction(double fraction) { - int proposedIndex = (int) (source.from + fraction * (source.to - source.from)); - if (proposedIndex <= current) { - return null; - } - Read primary = new Read(source.from, proposedIndex, source.produceTimestamps); - Read residual = new Read(proposedIndex, source.to, source.produceTimestamps); - this.source = primary; - return residual; - } - - @Override - public Double getFractionConsumed() { - return (current == -1) - ? 0.0 - : (1.0 * (1 + current - source.from) / (source.to - source.from)); - } - } - } - } - - @Test - public void testDirectPipelineWithoutTimestamps() throws Exception { - Pipeline p = TestPipeline.create(); - PCollection sum = p - .apply(Read.from(TestIO.fromRange(10, 20))) - .apply(Sum.integersGlobally()) - .apply(Sample.any(1)); - - PAssert.thatSingleton(sum).isEqualTo(145); - p.run(); - } - - @Test - public void testDirectPipelineWithTimestamps() throws Exception { - Pipeline p = TestPipeline.create(); - PCollection sums = - p.apply(Read.from(TestIO.fromRange(10, 20).withTimestampsMillis())) - .apply(Window.into(FixedWindows.of(Duration.millis(3)))) - .apply(Sum.integersGlobally().withoutDefaults()); - // Should group into [10 11] [12 13 14] [15 16 17] [18 19]. - PAssert.that(sums).containsInAnyOrder(21, 37, 39, 48); - p.run(); - } - - @Test - public void testRangeProgressAndSplitAtFraction() throws Exception { - // Show basic usage of getFractionConsumed and splitAtFraction. - // This test only tests TestIO itself, not BasicSerializableSourceFormat. - - DataflowPipelineOptions options = - PipelineOptionsFactory.create().as(DataflowPipelineOptions.class); - TestIO.Read source = TestIO.fromRange(10, 20); - try (BoundedSource.BoundedReader reader = source.createReader(options)) { - assertEquals(0, reader.getFractionConsumed().intValue()); - assertTrue(reader.start()); - assertEquals(0.1, reader.getFractionConsumed(), 1e-6); - assertTrue(reader.advance()); - assertEquals(0.2, reader.getFractionConsumed(), 1e-6); - // Already past 0.0 and 0.1. - assertNull(reader.splitAtFraction(0.0)); - assertNull(reader.splitAtFraction(0.1)); - - { - TestIO.Read residual = (TestIO.Read) reader.splitAtFraction(0.5); - assertNotNull(residual); - TestIO.Read primary = (TestIO.Read) reader.getCurrentSource(); - assertThat(readFromSource(primary, options), contains(10, 11, 12, 13, 14)); - assertThat(readFromSource(residual, options), contains(15, 16, 17, 18, 19)); - } - - // Range is now [10, 15) and we are at 12. - { - TestIO.Read residual = (TestIO.Read) reader.splitAtFraction(0.8); // give up 14. 
- assertNotNull(residual); - TestIO.Read primary = (TestIO.Read) reader.getCurrentSource(); - assertThat(readFromSource(primary, options), contains(10, 11, 12, 13)); - assertThat(readFromSource(residual, options), contains(14)); - } - - assertTrue(reader.advance()); - assertEquals(12, reader.getCurrent().intValue()); - assertTrue(reader.advance()); - assertEquals(13, reader.getCurrent().intValue()); - assertFalse(reader.advance()); - } - } -} From 89d20a2d66319269082cdead70eb3cf10309b9e8 Mon Sep 17 00:00:00 2001 From: Pei He Date: Wed, 18 May 2016 17:46:24 -0700 Subject: [PATCH 19/21] [BEAM-48] Add BigQueryTornadoes integration test --- .../examples/cookbook/BigQueryTornadoes.java | 2 +- .../cookbook/BigQueryTornadoesIT.java | 52 +++++++++++ .../org/apache/beam/sdk/io/BigQueryIO.java | 87 +++++++++++-------- .../apache/beam/sdk/io/BigQueryIOTest.java | 17 ++-- 4 files changed, 112 insertions(+), 46 deletions(-) create mode 100644 examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesIT.java diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryTornadoes.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryTornadoes.java index 80a2f25569ab..4c69efb81ca5 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryTornadoes.java +++ b/examples/java/src/main/java/org/apache/beam/examples/cookbook/BigQueryTornadoes.java @@ -143,7 +143,7 @@ public PCollection apply(PCollection rows) { * *
    Inherits standard configuration options. */ - private static interface Options extends PipelineOptions { + static interface Options extends PipelineOptions { @Description("Table to read from, specified as " + ":.") @Default.String(WEATHER_SAMPLES_TABLE) diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesIT.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesIT.java new file mode 100644 index 000000000000..fbd775cf50c8 --- /dev/null +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesIT.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.examples.cookbook; + +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * End-to-end tests of BigQueryTornadoes. + */ +@RunWith(JUnit4.class) +public class BigQueryTornadoesIT { + + /** + * Options for the BigQueryTornadoes Integration Test. 
+ */ + public interface BigQueryTornadoesITOptions + extends TestPipelineOptions, BigQueryTornadoes.Options { + } + + @Test + public void testE2EBigQueryTornadoes() throws Exception { + PipelineOptionsFactory.register(BigQueryTornadoesITOptions.class); + BigQueryTornadoesITOptions options = + TestPipeline.testingPipelineOptions().as(BigQueryTornadoesITOptions.class); + options.setOutput(String.format("%s.%s", + "BigQueryTornadoesIT", "monthly_tornadoes_" + System.currentTimeMillis())); + + BigQueryTornadoes.main(TestPipeline.convertToArgs(options)); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java index e4a306adc1cf..030dde031d9d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BigQueryIO.java @@ -559,34 +559,23 @@ public PCollection apply(PInput input) { String.format("Failed to resolve extract destination directory in %s", tempLocation)); } + final String executingProject = bqOptions.getProject(); if (!Strings.isNullOrEmpty(query)) { - String projectId = bqOptions.getProject(); String queryTempDatasetId = "temp_dataset_" + uuid; String queryTempTableId = "temp_table_" + uuid; TableReference queryTempTableRef = new TableReference() - .setProjectId(projectId) + .setProjectId(executingProject) .setDatasetId(queryTempDatasetId) .setTableId(queryTempTableId); - String jsonQueryTempTable; - try { - jsonQueryTempTable = JSON_FACTORY.toString(queryTempTableRef); - } catch (IOException e) { - throw new RuntimeException("Cannot initialize table to JSON strings.", e); - } source = BigQueryQuerySource.create( - jobIdToken, query, jsonQueryTempTable, flattenResults, + jobIdToken, query, queryTempTableRef, flattenResults, extractDestinationDir, bqServices); } else { - String jsonTable; - try { - jsonTable = JSON_FACTORY.toString(getTableWithDefaultProject(bqOptions)); - } catch (IOException e) { - throw new RuntimeException("Cannot initialize table to JSON strings.", e); - } + TableReference inputTable = getTableWithDefaultProject(bqOptions); source = BigQueryTableSource.create( - jobIdToken, jsonTable, extractDestinationDir, bqServices); + jobIdToken, inputTable, extractDestinationDir, bqServices, executingProject); } PassThroughThenCleanup.CleanupOperation cleanupOperation = new PassThroughThenCleanup.CleanupOperation() { @@ -595,7 +584,7 @@ void cleanup(PipelineOptions options) throws Exception { BigQueryOptions bqOptions = options.as(BigQueryOptions.class); JobReference jobRef = new JobReference() - .setProjectId(bqOptions.getProject()) + .setProjectId(executingProject) .setJobId(getExtractJobId(jobIdToken)); Job extractJob = bqServices.getJobService(bqOptions).pollJob( jobRef, CLEANUP_JOB_POLL_MAX_RETRIES); @@ -759,10 +748,12 @@ static class BigQueryTableSource extends BigQuerySourceBase { static BigQueryTableSource create( String jobIdToken, - String jsonTable, + TableReference table, String extractDestinationDir, - BigQueryServices bqServices) { - return new BigQueryTableSource(jobIdToken, jsonTable, extractDestinationDir, bqServices); + BigQueryServices bqServices, + String executingProject) { + return new BigQueryTableSource( + jobIdToken, table, extractDestinationDir, bqServices, executingProject); } private final String jsonTable; @@ -770,11 +761,17 @@ static BigQueryTableSource create( private BigQueryTableSource( String jobIdToken, - String jsonTable, + TableReference table, String 
extractDestinationDir, - BigQueryServices bqServices) { - super(jobIdToken, extractDestinationDir, bqServices); - this.jsonTable = checkNotNull(jsonTable, "jsonTable"); + BigQueryServices bqServices, + String executingProject) { + super(jobIdToken, extractDestinationDir, bqServices, executingProject); + checkNotNull(table, "table"); + try { + this.jsonTable = JSON_FACTORY.toString(table); + } catch (IOException e) { + throw new RuntimeException("Cannot initialize table to JSON strings.", e); + } this.tableSizeBytes = new AtomicReference<>(); } @@ -824,12 +821,17 @@ static class BigQueryQuerySource extends BigQuerySourceBase { static BigQueryQuerySource create( String jobIdToken, String query, - String jsonQueryTempTable, + TableReference queryTempTableRef, Boolean flattenResults, String extractDestinationDir, BigQueryServices bqServices) { return new BigQueryQuerySource( - jobIdToken, query, jsonQueryTempTable, flattenResults, extractDestinationDir, bqServices); + jobIdToken, + query, + queryTempTableRef, + flattenResults, + extractDestinationDir, + bqServices); } private final String query; @@ -840,13 +842,18 @@ static BigQueryQuerySource create( private BigQueryQuerySource( String jobIdToken, String query, - String jsonQueryTempTable, + TableReference queryTempTableRef, Boolean flattenResults, String extractDestinationDir, BigQueryServices bqServices) { - super(jobIdToken, extractDestinationDir, bqServices); + super(jobIdToken, extractDestinationDir, bqServices, + checkNotNull(queryTempTableRef, "queryTempTableRef").getProjectId()); this.query = checkNotNull(query, "query"); - this.jsonQueryTempTable = checkNotNull(jsonQueryTempTable, "jsonQueryTempTable"); + try { + this.jsonQueryTempTable = JSON_FACTORY.toString(queryTempTableRef); + } catch (IOException e) { + throw new RuntimeException("Cannot initialize table to JSON strings.", e); + } this.flattenResults = checkNotNull(flattenResults, "flattenResults"); this.dryRunJobStats = new AtomicReference<>(); } @@ -861,7 +868,7 @@ public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { public BoundedReader createReader(PipelineOptions options) throws IOException { BigQueryOptions bqOptions = options.as(BigQueryOptions.class); return new BigQueryReader(this, bqServices.getReaderFromQuery( - bqOptions, query, bqOptions.getProject(), flattenResults)); + bqOptions, query, executingProject, flattenResults)); } @Override @@ -887,7 +894,12 @@ protected TableReference getTableToExtract(BigQueryOptions bqOptions) // 3. Execute the query. 
String queryJobId = jobIdToken + "-query"; executeQuery( - queryJobId, query, tableToExtract, flattenResults, bqServices.getJobService(bqOptions)); + executingProject, + queryJobId, + query, + tableToExtract, + flattenResults, + bqServices.getJobService(bqOptions)); return tableToExtract; } @@ -912,22 +924,22 @@ public void populateDisplayData(DisplayData.Builder builder) { private synchronized JobStatistics dryRunQueryIfNeeded(BigQueryOptions bqOptions) throws InterruptedException, IOException { if (dryRunJobStats.get() == null) { - String projectId = bqOptions.getProject(); JobStatistics jobStats = - bqServices.getJobService(bqOptions).dryRunQuery(projectId, query); + bqServices.getJobService(bqOptions).dryRunQuery(executingProject, query); dryRunJobStats.compareAndSet(null, jobStats); } return dryRunJobStats.get(); } private static void executeQuery( + String executingProject, String jobId, String query, TableReference destinationTable, boolean flattenResults, JobService jobService) throws IOException, InterruptedException { JobReference jobRef = new JobReference() - .setProjectId(destinationTable.getProjectId()) + .setProjectId(executingProject) .setJobId(jobId); JobConfigurationQuery queryConfig = new JobConfigurationQuery(); queryConfig @@ -978,14 +990,17 @@ private abstract static class BigQuerySourceBase extends BoundedSource protected final String jobIdToken; protected final String extractDestinationDir; protected final BigQueryServices bqServices; + protected final String executingProject; private BigQuerySourceBase( String jobIdToken, String extractDestinationDir, - BigQueryServices bqServices) { + BigQueryServices bqServices, + String executingProject) { this.jobIdToken = checkNotNull(jobIdToken, "jobIdToken"); this.extractDestinationDir = checkNotNull(extractDestinationDir, "extractDestinationDir"); this.bqServices = checkNotNull(bqServices, "bqServices"); + this.executingProject = checkNotNull(executingProject, "executingProject"); } @Override @@ -1029,7 +1044,7 @@ private List executeExtract( String jobId, TableReference table, JobService jobService) throws InterruptedException, IOException { JobReference jobRef = new JobReference() - .setProjectId(table.getProjectId()) + .setProjectId(executingProject) .setJobId(jobId); String destinationUri = getExtractDestinationUri(extractDestinationDir); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/BigQueryIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/BigQueryIOTest.java index 6849018d03f6..2d1b5505a8bd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/BigQueryIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/BigQueryIOTest.java @@ -870,10 +870,10 @@ public void testBigQueryTableSourceThroughJsonAPI() throws Exception { toJsonString(new TableRow().set("name", "c").set("number", "3"))); String jobIdToken = "testJobIdToken"; - String jsonTable = toJsonString(BigQueryIO.parseTableSpec("project.data_set.table_name")); + TableReference table = BigQueryIO.parseTableSpec("project.data_set.table_name"); String extractDestinationDir = "mock://tempLocation"; - BoundedSource bqSource = - BigQueryTableSource.create(jobIdToken, jsonTable, extractDestinationDir, fakeBqServices); + BoundedSource bqSource = BigQueryTableSource.create( + jobIdToken, table, extractDestinationDir, fakeBqServices, "project"); List expected = ImmutableList.of( new TableRow().set("name", "a").set("number", "1"), @@ -907,10 +907,10 @@ public void testBigQueryTableSourceInitSplit() throws Exception { 
toJsonString(new TableRow().set("name", "c").set("number", "3"))); String jobIdToken = "testJobIdToken"; - String jsonTable = toJsonString(BigQueryIO.parseTableSpec("project.data_set.table_name")); + TableReference table = BigQueryIO.parseTableSpec("project:data_set.table_name"); String extractDestinationDir = "mock://tempLocation"; - BoundedSource bqSource = - BigQueryTableSource.create(jobIdToken, jsonTable, extractDestinationDir, fakeBqServices); + BoundedSource bqSource = BigQueryTableSource.create( + jobIdToken, table, extractDestinationDir, fakeBqServices, "project"); List expected = ImmutableList.of( new TableRow().set("name", "a").set("number", "1"), @@ -973,10 +973,9 @@ public void testBigQueryQuerySourceInitSplit() throws Exception { String jobIdToken = "testJobIdToken"; String extractDestinationDir = "mock://tempLocation"; - TableReference destinationTable = BigQueryIO.parseTableSpec("project.data_set.table_name"); - String jsonDestinationTable = toJsonString(destinationTable); + TableReference destinationTable = BigQueryIO.parseTableSpec("project:data_set.table_name"); BoundedSource bqSource = BigQueryQuerySource.create( - jobIdToken, "query", jsonDestinationTable, true /* flattenResults */, + jobIdToken, "query", destinationTable, true /* flattenResults */, extractDestinationDir, fakeBqServices); List expected = ImmutableList.of( From 32a6cde4e43726849713a7183c66aa28f43b0868 Mon Sep 17 00:00:00 2001 From: Dan Halperin Date: Tue, 3 May 2016 17:53:48 -0700 Subject: [PATCH 20/21] BoundedReader: add getSplitPoints{Consumed,Remaining} And implement and test it for common sources OffsetBasedReader: test limited parallelism signals AvroSource: rewrite to support remaining parallelism *) Make the start of a block match Avro's definition: the first byte after the previous sync marker. This enables detecting the last block in the file. *) This change enables us to unify currentOffset and currentBlockOffset, as all records are emitted at the start of the block that contains them. *) Simplify block header reading to have fewer object allocations and buffers using a direct reader and a (allocated once only) CountingInputStream to measure the size of that header. *) Add tests for consumed and remaining parallelism *) Let BlockBasedSource detect the end of the file in remaining parallelism. *) Verify in more places that the correct number of bytes is read from the input Avro file. CompressedSource: add tests of parallelism and progress *) empty file *) non-empty compressed file *) non-empty not-compressed file TextIO: implement and test parallelism *) empty file *) non-empty file CountingSource: test limited parallelism CompressedSource: implement currentOffset based on bytes decompressed *) This is not a very good offset because it is an upper bound, but it is likely better than not reporting any progress at all. 
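As a rough illustration of how a runner might consume these new signals together with the existing fraction-based API, consider the hypothetical helper below. The class SplitPolicySketch and its maybeSplit policy are invented for illustration only; getSplitPointsRemaining, getFractionConsumed, and splitAtFraction are the BoundedReader methods touched by this change.

    import org.apache.beam.sdk.io.BoundedSource;
    import org.apache.beam.sdk.io.BoundedSource.BoundedReader;

    /** Hypothetical, runner-side sketch; not part of this patch. */
    class SplitPolicySketch {
      /** Returns a residual source if a dynamic split looks worthwhile, or null otherwise. */
      static <T> BoundedSource<T> maybeSplit(BoundedReader<T> reader) {
        long remaining = reader.getSplitPointsRemaining();
        if (remaining == 0 || remaining == 1) {
          // Nothing beyond the current task is left to split off.
          return null;
        }
        // remaining > 1, or SPLIT_POINTS_UNKNOWN: attempt to split halfway through the
        // unconsumed part of the range, falling back to 0.5 when progress is unknown.
        Double consumed = reader.getFractionConsumed();
        double fraction = (consumed == null) ? 0.5 : consumed + (1.0 - consumed) / 2;
        return reader.splitAtFraction(fraction);
      }
    }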
--- .../org/apache/beam/sdk/io/AvroSource.java | 166 ++++++++++++------ .../apache/beam/sdk/io/BlockBasedSource.java | 26 +-- .../org/apache/beam/sdk/io/BoundedSource.java | 145 ++++++++++++++- .../apache/beam/sdk/io/CompressedSource.java | 132 ++++++++++++-- .../apache/beam/sdk/io/CountingSource.java | 5 + .../org/apache/beam/sdk/io/DatastoreIO.java | 13 ++ .../apache/beam/sdk/io/FileBasedSource.java | 2 +- .../apache/beam/sdk/io/OffsetBasedSource.java | 49 +++++- .../java/org/apache/beam/sdk/io/TextIO.java | 20 ++- .../beam/sdk/io/range/OffsetRangeTracker.java | 109 ++++++++++-- .../apache/beam/sdk/io/AvroSourceTest.java | 86 ++++++++- .../beam/sdk/io/CompressedSourceTest.java | 107 ++++++++++- .../beam/sdk/io/CountingSourceTest.java | 30 ++++ .../beam/sdk/io/FileBasedSourceTest.java | 2 +- .../beam/sdk/io/OffsetBasedSourceTest.java | 71 +++++++- .../org/apache/beam/sdk/io/TextIOTest.java | 114 +++++++++++- .../sdk/io/range/OffsetRangeTrackerTest.java | 1 - .../beam/sdk/io/hdfs/HDFSFileSource.java | 12 ++ 18 files changed, 969 insertions(+), 121 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java index ef8e4273c18c..255199f8462f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io; +import static com.google.common.base.Preconditions.checkState; + import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.options.PipelineOptions; @@ -40,18 +42,24 @@ import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; import org.apache.commons.compress.compressors.snappy.SnappyCompressorInputStream; import org.apache.commons.compress.compressors.xz.XZCompressorInputStream; +import org.apache.commons.compress.utils.CountingInputStream; import java.io.ByteArrayInputStream; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.PushbackInputStream; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.util.Arrays; import java.util.Collection; import java.util.zip.Inflater; import java.util.zip.InflaterInputStream; +import javax.annotation.concurrent.GuardedBy; + // CHECKSTYLE.OFF: JavadocStyle /** * A {@link FileBasedSource} for reading Avro files. @@ -439,10 +447,6 @@ public double getFractionOfBlockConsumed() { * the total number of records in the block and the block's size in bytes, followed by the * block's (optionally-encoded) records. Each block is terminated by a 16-bit sync marker. * - *

    Here, we consider the sync marker that precedes a block to be its offset, as this allows - * a reader that begins reading at that offset to detect the sync marker and the beginning of - * the block. - * * @param The type of records contained in the block. */ @Experimental(Experimental.Kind.SOURCE_SINK) @@ -450,24 +454,25 @@ public static class AvroReader extends BlockBasedReader { // The current block. private AvroBlock currentBlock; - // Offset of the block. + // A lock used to synchronize block offsets for getRemainingParallelism + private final Object progressLock = new Object(); + + // Offset of the current block. + @GuardedBy("progressLock") private long currentBlockOffset = 0; // Size of the current block. + @GuardedBy("progressLock") private long currentBlockSizeBytes = 0; - // Current offset within the stream. - private long currentOffset = 0; - // Stream used to read from the underlying file. - // A pushback stream is used to restore bytes buffered during seeking/decoding. + // A pushback stream is used to restore bytes buffered during seeking. private PushbackInputStream stream; + // Counts the number of bytes read. Used only to tell how many bytes are taken up in + // a block's variable-length header. + private CountingInputStream countStream; - // Small buffer for reading encoded values from the stream. - // The maximum size of an encoded long is 10 bytes, and this buffer will be used to read two. - private final byte[] readBuffer = new byte[20]; - - // Decoder to decode binary-encoded values from the buffer. + // Caches the Avro DirectBinaryDecoder used to decode binary-encoded values from the buffer. private BinaryDecoder decoder; /** @@ -482,51 +487,67 @@ public synchronized AvroSource getCurrentSource() { return (AvroSource) super.getCurrentSource(); } + // Precondition: the stream is positioned after the sync marker in the current (about to be + // previous) block. currentBlockSize equals the size of the current block, or zero if this + // reader was just started. + // + // Postcondition: same as above, but for the new current (formerly next) block. @Override public boolean readNextBlock() throws IOException { - // The next block in the file is after the first sync marker that can be read starting from - // the current offset. First, we seek past the next sync marker, if it exists. After a sync - // marker is the start of a block. A block begins with the number of records contained in - // the block, encoded as a long, followed by the size of the block in bytes, encoded as a - // long. The currentOffset after this method should be last byte after this block, and the - // currentBlockOffset should be the start of the sync marker before this block. - - // Seek to the next sync marker, if one exists. - currentOffset += advancePastNextSyncMarker(stream, getCurrentSource().getSyncMarker()); - - // The offset of the current block includes its preceding sync marker. - currentBlockOffset = currentOffset - getCurrentSource().getSyncMarker().length; - - // Read a small buffer to parse the block header. - // We cannot use a BinaryDecoder to do this directly from the stream because a BinaryDecoder - // internally buffers data and we only want to read as many bytes from the stream as the size - // of the header. Though BinaryDecoder#InputStream returns an input stream that is aware of - // its internal buffering, we would have to re-wrap this input stream to seek for the next - // block in the file. 
- int read = stream.read(readBuffer); - // We reached the last sync marker in the file. - if (read <= 0) { + long startOfNextBlock = currentBlockOffset + currentBlockSizeBytes; + + // Before reading the variable-sized block header, record the current number of bytes read. + long preHeaderCount = countStream.getBytesRead(); + decoder = DecoderFactory.get().directBinaryDecoder(countStream, decoder); + long numRecords; + try { + numRecords = decoder.readLong(); + } catch (EOFException e) { + // Expected for the last block, at which the start position is the EOF. The way to detect + // stream ending is to try reading from it. return false; } - decoder = DecoderFactory.get().binaryDecoder(readBuffer, decoder); - long numRecords = decoder.readLong(); long blockSize = decoder.readLong(); - // The decoder buffers data internally, but since we know the size of the stream the - // decoder has constructed from the readBuffer, the number of bytes available in the - // input stream is equal to the number of unconsumed bytes. - int headerSize = readBuffer.length - decoder.inputStream().available(); - stream.unread(readBuffer, headerSize, read - headerSize); + // Mark header size as the change in the number of bytes read. + long headerSize = countStream.getBytesRead() - preHeaderCount; // Create the current block by reading blockSize bytes. Block sizes permitted by the Avro // specification are [32, 2^30], so this narrowing is ok. byte[] data = new byte[(int) blockSize]; - stream.read(data); + int read = stream.read(data); + checkState(blockSize == read, "Only %s/%s bytes in the block were read", read, blockSize); currentBlock = new AvroBlock<>(data, numRecords, getCurrentSource()); - currentBlockSizeBytes = blockSize; - // Update current offset with the number of bytes we read to get the next block. - currentOffset += headerSize + blockSize; + // Read the end of this block, which MUST be a sync marker for correctness. + byte[] syncMarker = getCurrentSource().getSyncMarker(); + byte[] readSyncMarker = new byte[syncMarker.length]; + long syncMarkerOffset = startOfNextBlock + headerSize + blockSize; + long bytesRead = stream.read(readSyncMarker); + checkState( + bytesRead == syncMarker.length, + "When trying to read a sync marker at position %s, only able to read %s/%s bytes", + syncMarkerOffset, + bytesRead, + syncMarker.length); + if (!Arrays.equals(syncMarker, readSyncMarker)) { + throw new IllegalStateException( + String.format( + "Expected the bytes [%d,%d) in file %s to be a sync marker, but found %s", + syncMarkerOffset, + syncMarkerOffset + syncMarker.length, + getCurrentSource().getFileOrPatternSpec(), + Arrays.toString(readSyncMarker) + )); + } + + // Atomically update both the position and offset of the new block. + synchronized (progressLock) { + currentBlockOffset = startOfNextBlock; + // Total block size includes the header, block content, and trailing sync marker. 
+ currentBlockSizeBytes = headerSize + blockSize + syncMarker.length; + } + return true; } @@ -537,32 +558,65 @@ public AvroBlock getCurrentBlock() { @Override public long getCurrentBlockOffset() { - return currentBlockOffset; + synchronized (progressLock) { + return currentBlockOffset; + } } @Override public long getCurrentBlockSize() { - return currentBlockSizeBytes; + synchronized (progressLock) { + return currentBlockSizeBytes; + } + } + + @Override + public long getSplitPointsRemaining() { + if (isDone()) { + return 0; + } + synchronized (progressLock) { + if (currentBlockOffset + currentBlockSizeBytes >= getCurrentSource().getEndOffset()) { + // This block is known to be the last block in the range. + return 1; + } + } + return super.getSplitPointsRemaining(); } /** * Creates a {@link PushbackInputStream} that has a large enough pushback buffer to be able - * to push back the syncBuffer and the readBuffer. + * to push back the syncBuffer. */ private PushbackInputStream createStream(ReadableByteChannel channel) { return new PushbackInputStream( Channels.newInputStream(channel), - getCurrentSource().getSyncMarker().length + readBuffer.length); + getCurrentSource().getSyncMarker().length); } - /** - * Starts reading from the provided channel. Assumes that the channel is already seeked to - * the source's start offset. - */ + // Postcondition: the stream is positioned at the beginning of the first block after the start + // of the current source, and currentBlockOffset is that position. Additionally, + // currentBlockSizeBytes will be set to 0 indicating that the previous block was empty. @Override protected void startReading(ReadableByteChannel channel) throws IOException { + long startOffset = getCurrentSource().getStartOffset(); + byte[] syncMarker = getCurrentSource().getSyncMarker(); + long syncMarkerLength = syncMarker.length; + + if (startOffset != 0) { + // Rewind order to find the sync marker ending the previous block. + long position = Math.max(0, startOffset - syncMarkerLength); + ((SeekableByteChannel) channel).position(position); + startOffset = position; + } + + // Satisfy the post condition. stream = createStream(channel); - currentOffset = getCurrentSource().getStartOffset(); + countStream = new CountingInputStream(stream); + synchronized (progressLock) { + currentBlockOffset = startOffset + advancePastNextSyncMarker(stream, syncMarker); + currentBlockSizeBytes = 0; + } } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java index 31ef0556a809..997c77a1273f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java @@ -206,28 +206,32 @@ protected final boolean readNextRecord() throws IOException { } @Override + @Nullable public Double getFractionConsumed() { - if (getCurrentSource().getEndOffset() == Long.MAX_VALUE) { - return null; - } - Block currentBlock = getCurrentBlock(); - if (currentBlock == null) { - // There is no current block (i.e., the read has not yet begun). + if (!isStarted()) { return 0.0; } + if (isDone()) { + return 1.0; + } + FileBasedSource source = getCurrentSource(); + if (source.getEndOffset() == Long.MAX_VALUE) { + // Unknown end offset, so we cannot tell. 
+ return null; + } + long currentBlockOffset = getCurrentBlockOffset(); - long startOffset = getCurrentSource().getStartOffset(); - long endOffset = getCurrentSource().getEndOffset(); + long startOffset = source.getStartOffset(); + long endOffset = source.getEndOffset(); double fractionAtBlockStart = ((double) (currentBlockOffset - startOffset)) / (endOffset - startOffset); double fractionAtBlockEnd = ((double) (currentBlockOffset + getCurrentBlockSize() - startOffset) / (endOffset - startOffset)); + double blockFraction = getCurrentBlock().getFractionOfBlockConsumed(); return Math.min( 1.0, - fractionAtBlockStart - + currentBlock.getFractionOfBlockConsumed() - * (fractionAtBlockEnd - fractionAtBlockStart)); + fractionAtBlockStart + blockFraction * (fractionAtBlockEnd - fractionAtBlockStart)); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedSource.java index 8f7d3fdcbe93..394afa4bb3c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedSource.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.io; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.io.range.OffsetRangeTracker; +import org.apache.beam.sdk.io.range.RangeTracker; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -27,6 +29,8 @@ import java.util.List; import java.util.NoSuchElementException; +import javax.annotation.Nullable; + /** * A {@link Source} that reads a finite amount of input and, because of that, supports * some additional operations. @@ -37,9 +41,16 @@ *

  • Size estimation: {@link #getEstimatedSizeBytes}; *
  • Telling whether or not this source produces key/value pairs in sorted order: * {@link #producesSortedKeys}; - *
  • The reader ({@link BoundedReader}) supports progress estimation - * ({@link BoundedReader#getFractionConsumed}) and dynamic splitting - * ({@link BoundedReader#splitAtFraction}). + *
  • The accompanying {@link BoundedReader reader} has additional functionality to enable runners + * to dynamically adapt based on runtime conditions. + *
      + *
    • Progress estimation ({@link BoundedReader#getFractionConsumed}) + *
    • Tracking of parallelism, to determine whether the current source can be split + * ({@link BoundedReader#getSplitPointsConsumed()} and + * {@link BoundedReader#getSplitPointsRemaining()}). + *
    • Dynamic splitting of the current source ({@link BoundedReader#splitAtFraction}). + *
    + *
  • * * *

    To use this class for supporting your custom input type, derive your class @@ -82,14 +93,14 @@ public abstract List> splitIntoBundles( * *

    Thread safety

    * All methods will be run from the same thread except {@link #splitAtFraction}, - * {@link #getFractionConsumed} and {@link #getCurrentSource}, which can be called concurrently + * {@link #getFractionConsumed}, {@link #getCurrentSource}, {@link #getSplitPointsConsumed()}, + * and {@link #getSplitPointsRemaining()}, all of which can be called concurrently * from a different thread. There will not be multiple concurrent calls to - * {@link #splitAtFraction} but there can be for {@link #getFractionConsumed} if - * {@link #splitAtFraction} is implemented. + * {@link #splitAtFraction}. * - *

    If the source does not implement {@link #splitAtFraction}, you do not need to worry about - * thread safety. If implemented, it must be safe to call {@link #splitAtFraction} and - * {@link #getFractionConsumed} concurrently with other methods. + *

    It must be safe to call {@link #splitAtFraction}, {@link #getFractionConsumed}, + * {@link #getCurrentSource}, {@link #getSplitPointsConsumed()}, and + * {@link #getSplitPointsRemaining()} concurrently with other methods. * *

    Additionally, a successful {@link #splitAtFraction} call must, by definition, cause * {@link #getCurrentSource} to start returning a different value. @@ -129,10 +140,125 @@ public abstract static class BoundedReader extends Source.Reader { * methods (including itself), and it is therefore critical for it to be implemented * in a thread-safe way. */ + @Nullable public Double getFractionConsumed() { return null; } + /** + * A constant to use as the return value for {@link #getSplitPointsConsumed()} or + * {@link #getSplitPointsRemaining()} when the exact value is unknown. + */ + public static final long SPLIT_POINTS_UNKNOWN = -1; + + /** + * Returns the total amount of parallelism in the consumed (returned and processed) range of + * this reader's current {@link BoundedSource} (as would be returned by + * {@link #getCurrentSource}). This corresponds to all split point records (see + * {@link RangeTracker}) returned by this reader, excluding the last split point + * returned if the reader is not finished. + * + *

    Consider the following examples: (1) An input that can be read in parallel down to the + * individual records, such as {@link CountingSource#upTo}, is called "perfectly splittable". + * (2) A "block-compressed" file format such as {@link AvroIO}, in which a block of records has + * to be read as a whole, but different blocks can be read in parallel. (3) An "unsplittable" + * input such as a cursor in a database. + * + *

      + *
    • Any {@link BoundedReader reader} that is unstarted (aka, has never had a call to + * {@link #start}) has a consumed parallelism of 0. This condition holds independent of whether + * the input is splittable. + *
    • Any {@link BoundedReader reader} that has only returned its first element (aka, + * has never had a call to {@link #advance}) has a consumed parallelism of 0: the first element + * is the current element and is still being processed. This condition holds independent of + * whether the input is splittable. + *
    • For an empty reader (in which the call to {@link #start} returned false), the + * consumed parallelism is 0. This condition holds independent of whether the input is + * splittable. + *
    • For a non-empty, finished reader (in which the call to {@link #start} returned true and + * a call to {@link #advance} has returned false), the value returned must be at least 1 + * and should equal the total parallelism in the source. + *
    • For example (1): After returning record #30 (starting at 1) out of 50 in a perfectly + * splittable 50-record input, this value should be 29. When finished, the consumed parallelism + * should be 50. + *
    • For example (2): In a block-compressed file consisting of 5 blocks, the value should + * stay at 0 until the first record of the second block is returned; stay at 1 until the first + * record of the third block is returned; and so on. Only once the end of the file is reached has the + * fifth block been consumed, and the value should then be 5. + *
    • For example (3): For any non-empty unsplittable input, the consumed parallelism is 0 + * until the reader is finished (because the last call to {@link #advance} returned false), at + * which point it becomes 1. + *
    + * + *

    A reader that is implemented using a {@link RangeTracker} is encouraged to use the + * range tracker's ability to count split points to implement this method. See + * {@link OffsetBasedSource.OffsetBasedReader} and {@link OffsetRangeTracker} for an example. + * + *

    Defaults to {@link #SPLIT_POINTS_UNKNOWN}. Any value less than 0 will be interpreted + * as unknown. + * + *

    Thread safety

    + * See the javadoc on {@link BoundedReader} for information about thread safety. + * + * @see #getSplitPointsRemaining() + */ + public long getSplitPointsConsumed() { + return SPLIT_POINTS_UNKNOWN; + } + + /** + * Returns the total amount of parallelism in the unprocessed part of this reader's current + * {@link BoundedSource} (as would be returned by {@link #getCurrentSource}). This corresponds + * to all unprocessed split point records (see {@link RangeTracker}), including the last + * split point returned, in the remainder part of the source. + * + *

    This function should be implemented only in addition to + * {@link #getSplitPointsConsumed()} and only if an exact value can be + * returned. + * + *

    Consider the following examples: (1) An input that can be read in parallel down to the + * individual records, such as {@link CountingSource#upTo}, is called "perfectly splittable". + * (2) A "block-compressed" file format such as {@link AvroIO}, in which a block of records has + * to be read as a whole, but different blocks can be read in parallel. (3) An "unsplittable" + * input such as a cursor in a database. + * + *

    Assume for examples (1) and (2) that the number of records or blocks remaining is known: + * + *

      + *
    • Any {@link BoundedReader reader} for which the last call to {@link #start} or + * {@link #advance} has returned true should not return 0, because this reader itself + * represents parallelism of at least 1. This condition holds independent of whether the input is + * splittable. + *
    • A finished reader (for which {@link #start} or {@link #advance} has returned false) + * should return a value of 0. This condition holds independent of whether the input is + * splittable. + *
    • For example (1): After returning record #30 (starting at 1) out of 50 in a perfectly + * splittable 50-record input, this value should be 21 (20 remaining + 1 current) if the total + * number of records is known. + *
    • For example (2): After returning a record in block 3 in a block-compressed file + * consisting of 5 blocks, this value should be 3 (since blocks 4 and 5 can be processed in + * parallel by new readers produced via dynamic work rebalancing, while the current reader + * continues processing block 3) if the total number of blocks is known. + *
    • For example (3): A reader for any non-empty unsplittable input should return 1 until + * it is finished, at which point it should return 0. + *
    • For any reader: After returning the last split point in a file (e.g., the last record + * in example (1), the first record in the last block for example (2), or the first record in + * the file for example (3)), this value should be 1: apart from the current task, no additional + * remainder can be split off. + *
    + * + *

    Defaults to {@link #SPLIT_POINTS_UNKNOWN}. Any value less than 0 will be interpreted as + * unknown. + * + *

    Thread safety

    + * See the javadoc on {@link BoundedReader} for information about thread safety. + * + * @see #getSplitPointsConsumed() + */ + public long getSplitPointsRemaining() { + return SPLIT_POINTS_UNKNOWN; + } + /** * Returns a {@code Source} describing the same input that this {@code Reader} currently reads * (including items already read). @@ -263,6 +389,7 @@ public Double getFractionConsumed() { * *

    By default, returns null to indicate that splitting is not possible. */ + @Nullable public BoundedSource splitAtFraction(double fraction) { return null; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java index 5cb0684bb3ab..8bccf5f59d3f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java @@ -32,11 +32,14 @@ import java.io.IOException; import java.io.PushbackInputStream; import java.io.Serializable; +import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.util.NoSuchElementException; import java.util.zip.GZIPInputStream; +import javax.annotation.concurrent.GuardedBy; + /** * A Source that reads from compressed files. A {@code CompressedSources} wraps a delegate * {@link FileBasedSource} that is able to read the decompressed file format. @@ -361,7 +364,12 @@ public static class CompressedReader extends FileBasedReader { private final FileBasedReader readerDelegate; private final CompressedSource source; + private final boolean splittable; + private final Object progressLock = new Object(); + @GuardedBy("progressLock") private int numRecordsRead; + @GuardedBy("progressLock") + private CountingChannel channel; /** * Create a {@code CompressedReader} from a {@code CompressedSource} and delegate reader. @@ -369,6 +377,13 @@ public static class CompressedReader extends FileBasedReader { public CompressedReader(CompressedSource source, FileBasedReader readerDelegate) { super(source); this.source = source; + boolean splittable; + try { + splittable = source.isSplittable(); + } catch (Exception e) { + throw new RuntimeException("Unable to tell whether source " + source + " is splittable", e); + } + this.splittable = splittable; this.readerDelegate = readerDelegate; } @@ -380,18 +395,78 @@ public T getCurrent() throws NoSuchElementException { return readerDelegate.getCurrent(); } + @Override + public final long getSplitPointsConsumed() { + if (splittable) { + return readerDelegate.getSplitPointsConsumed(); + } else { + synchronized (progressLock) { + return (isDone() && numRecordsRead > 0) ? 1 : 0; + } + } + } + + @Override + public final long getSplitPointsRemaining() { + if (splittable) { + return readerDelegate.getSplitPointsRemaining(); + } else { + return isDone() ? 0 : 1; + } + } + /** * Returns true only for the first record; compressed sources cannot be split. */ @Override protected final boolean isAtSplitPoint() { - // We have to return true for the first record, but not for the state before reading it, - // and not for the state after reading any other record. Hence == rather than >= or <=. - // This is required because FileBasedReader is intended for readers that can read a range - // of offsets in a file and where the range can be split in parts. CompressedReader, - // however, is a degenerate case because it cannot be split, but it has to satisfy the - // semantics of offsets and split points anyway. - return numRecordsRead == 1; + if (splittable) { + return readerDelegate.isAtSplitPoint(); + } else { + // We have to return true for the first record, but not for the state before reading it, + // and not for the state after reading any other record. Hence == rather than >= or <=. 
+ // This is required because FileBasedReader is intended for readers that can read a range + // of offsets in a file and where the range can be split in parts. CompressedReader, + // however, is a degenerate case because it cannot be split, but it has to satisfy the + // semantics of offsets and split points anyway. + synchronized (progressLock) { + return numRecordsRead == 1; + } + } + } + + private static class CountingChannel implements ReadableByteChannel { + long count; + private final ReadableByteChannel inner; + + public CountingChannel(ReadableByteChannel inner, long count) { + this.inner = inner; + this.count = count; + } + + public long getCount() { + return count; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int bytes = inner.read(dst); + if (bytes > 0) { + // Avoid the -1 from EOF. + count += bytes; + } + return bytes; + } + + @Override + public boolean isOpen() { + return inner.isOpen(); + } + + @Override + public void close() throws IOException { + inner.close(); + } } /** @@ -400,6 +475,16 @@ protected final boolean isAtSplitPoint() { */ @Override protected final void startReading(ReadableByteChannel channel) throws IOException { + if (splittable) { + // No-op. We will always delegate to the inner reader, so this.channel and this.progressLock + // will never be used. + } else { + synchronized (progressLock) { + this.channel = new CountingChannel(channel, getCurrentSource().getStartOffset()); + channel = this.channel; + } + } + if (source.getChannelFactory() instanceof FileNameBasedDecompressingChannelFactory) { FileNameBasedDecompressingChannelFactory channelFactory = (FileNameBasedDecompressingChannelFactory) source.getChannelFactory(); @@ -420,16 +505,37 @@ protected final boolean readNextRecord() throws IOException { if (!readerDelegate.readNextRecord()) { return false; } - ++numRecordsRead; + synchronized (progressLock) { + ++numRecordsRead; + } return true; } - /** - * Returns the delegate reader's current offset in the decompressed input. - */ + // Splittable: simply delegates to the inner reader. + // + // Unsplittable: returns the offset in the input stream that has been read by the input. + // these positions are likely to be coarse-grained (in the event of buffering) and + // over-estimates (because they reflect the number of bytes read to produce an element, not its + // start) but both of these provide better data than e.g., reporting the start of the file. @Override - protected final long getCurrentOffset() { - return readerDelegate.getCurrentOffset(); + protected final long getCurrentOffset() throws NoSuchElementException { + if (splittable) { + return readerDelegate.getCurrentOffset(); + } else { + synchronized (progressLock) { + if (numRecordsRead <= 1) { + // Since the first record is at a split point, it should start at the beginning of the + // file. This avoids the bad case where the decompressor read the entire file, which + // would cause the file to be treated as empty when returning channel.getCount() as it + // is outside the valid range. 
+ return 0; + } + if (channel == null) { + throw new NoSuchElementException(); + } + return channel.getCount(); + } + } } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java index b28e8662d31a..403d22eba319 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java @@ -209,6 +209,11 @@ protected long getCurrentOffset() throws NoSuchElementException { return current; } + @Override + public synchronized long getSplitPointsRemaining() { + return Math.max(0, getCurrentSource().getEndOffset() - current); + } + @Override public synchronized BoundedCountingSource getCurrentSource() { return (BoundedCountingSource) super.getCurrentSource(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DatastoreIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DatastoreIO.java index cc8e9230d097..137c6cd3bc28 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DatastoreIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DatastoreIO.java @@ -865,6 +865,8 @@ public static class DatastoreReader extends BoundedSource.BoundedReader */ private int userLimit; + private volatile boolean done = false; + private Entity currentEntity; /** @@ -884,6 +886,16 @@ public Entity getCurrent() { return currentEntity; } + @Override + public final long getSplitPointsConsumed() { + return done ? 1 : 0; + } + + @Override + public final long getSplitPointsRemaining() { + return done ? 0 : 1; + } + @Override public boolean start() throws IOException { return advance(); @@ -901,6 +913,7 @@ public boolean advance() throws IOException { if (entities == null || !entities.hasNext()) { currentEntity = null; + done = true; return false; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java index 96aeda50f9df..f000f6a71eda 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java @@ -489,7 +489,7 @@ public FileBasedReader(FileBasedSource source) { } @Override - public FileBasedSource getCurrentSource() { + public synchronized FileBasedSource getCurrentSource() { return (FileBasedSource) super.getCurrentSource(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java index 9ee89a2bf799..2f62acd7a6fd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java @@ -180,7 +180,7 @@ public long getBytesPerOffset() { * *

    As an example in which {@link OffsetBasedSource} is used to implement a file source, suppose * that this source was constructed with an {@code endOffset} of {@link Long#MAX_VALUE} to - * indicate that a file should be read to the end. Then {@link #getMaxEndOffset} should determine + * indicate that a file should be read to the end. Then this function should determine * the actual, exact size of the file in bytes and return it. */ public abstract long getMaxEndOffset(PipelineOptions options) throws Exception; @@ -230,9 +230,22 @@ public void populateDisplayData(DisplayData.Builder builder) { */ public abstract static class OffsetBasedReader extends BoundedReader { private static final Logger LOG = LoggerFactory.getLogger(OffsetBasedReader.class); - private OffsetBasedSource source; + /** + * Returns true if the last call to {@link #start} or {@link #advance} returned false. + */ + public final boolean isDone() { + return rangeTracker.isDone(); + } + + /** + * Returns true if there has been a call to {@link #start}. + */ + public final boolean isStarted() { + return rangeTracker.isStarted(); + } + /** The {@link OffsetRangeTracker} managing the range and current position of the source. */ private final OffsetRangeTracker rangeTracker; @@ -266,12 +279,14 @@ protected boolean isAtSplitPoint() throws NoSuchElementException { @Override public final boolean start() throws IOException { - return startImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()); + return startImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()) + || rangeTracker.markDone(); } @Override public final boolean advance() throws IOException { - return advanceImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()); + return advanceImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint(), getCurrentOffset()) + || rangeTracker.markDone(); } /** @@ -314,6 +329,32 @@ public Double getFractionConsumed() { return rangeTracker.getFractionConsumed(); } + @Override + public long getSplitPointsConsumed() { + return rangeTracker.getSplitPointsProcessed(); + } + + @Override + public long getSplitPointsRemaining() { + if (isDone()) { + return 0; + } else if (!isStarted()) { + // Note that even if the current source does not allow splitting, we don't know that + // it's non-empty so we return UNKNOWN instead of 1. + return BoundedReader.SPLIT_POINTS_UNKNOWN; + } else if (!getCurrentSource().allowsDynamicSplitting()) { + // Started (so non-empty) and unsplittable, so only the current task. + return 1; + } else if (getCurrentOffset() >= rangeTracker.getStopPosition() - 1) { + // If this is true, the next element is outside the range. Note that even getCurrentOffset() + // might be larger than the stop position when the current record is not a split point. + return 1; + } else { + // Use the default. 
+ return super.getSplitPointsRemaining(); + } + } + @Override public final synchronized OffsetBasedSource splitAtFraction(double fraction) { if (!getCurrentSource().allowsDynamicSplitting()) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java index 79eeb081671b..13cb45e2099a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java @@ -817,9 +817,10 @@ static class TextBasedReader extends FileBasedReader { private ByteString buffer; private int startOfSeparatorInBuffer; private int endOfSeparatorInBuffer; - private long startOfNextRecord; - private boolean eof; - private boolean elementIsPresent; + private long startOfRecord; + private volatile long startOfNextRecord; + private volatile boolean eof; + private volatile boolean elementIsPresent; private T currentValue; private ReadableByteChannel inChannel; @@ -834,7 +835,15 @@ protected long getCurrentOffset() throws NoSuchElementException { if (!elementIsPresent) { throw new NoSuchElementException(); } - return startOfNextRecord; + return startOfRecord; + } + + @Override + public long getSplitPointsRemaining() { + if (isStarted() && startOfNextRecord >= getCurrentSource().getEndOffset()) { + return isDone() ? 0 : 1; + } + return super.getSplitPointsRemaining(); } @Override @@ -912,7 +921,7 @@ private void findSeparatorBounds() throws IOException { @Override protected boolean readNextRecord() throws IOException { - startOfNextRecord += endOfSeparatorInBuffer; + startOfRecord = startOfNextRecord; findSeparatorBounds(); // If we have reached EOF file and consumed all of the buffer then we know @@ -923,6 +932,7 @@ protected boolean readNextRecord() throws IOException { } decodeCurrentElement(); + startOfNextRecord = startOfRecord + endOfSeparatorInBuffer; return true; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java index ea1cf14e75d8..76790af08e5d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java @@ -17,6 +17,10 @@ */ package org.apache.beam.sdk.io.range; +import static com.google.common.base.Preconditions.checkState; + +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -32,6 +36,8 @@ public class OffsetRangeTracker implements RangeTracker { private long stopOffset; private long lastRecordStart = -1L; private long offsetOfLastSplitPoint = -1L; + private long splitPointsSeen = 0L; + private boolean done = false; /** * Offset corresponding to infinity. This can only be used as the upper-bound of a range, and @@ -49,6 +55,15 @@ public OffsetRangeTracker(long startOffset, long stopOffset) { this.stopOffset = stopOffset; } + public synchronized boolean isStarted() { + // done => started: handles the case when the reader was empty. 
+ return (offsetOfLastSplitPoint != -1) || done; + } + + public synchronized boolean isDone() { + return done; + } + @Override public synchronized Long getStartPosition() { return startOffset; @@ -65,10 +80,18 @@ public boolean tryReturnRecordAt(boolean isAtSplitPoint, Long recordStart) { } public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, long recordStart) { - if (lastRecordStart == -1 && !isAtSplitPoint) { + if (!isStarted() && !isAtSplitPoint) { throw new IllegalStateException( String.format("The first record [starting at %d] must be at a split point", recordStart)); } + if (recordStart < startOffset) { + throw new IllegalStateException( + String.format( + "Trying to return record [starting at %d] which is before the start offset [%d]", + recordStart, + startOffset)); + + } if (recordStart < lastRecordStart) { throw new IllegalStateException( String.format( @@ -77,8 +100,11 @@ public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, long recor recordStart, lastRecordStart)); } + + lastRecordStart = recordStart; + if (isAtSplitPoint) { - if (offsetOfLastSplitPoint != -1L && recordStart == offsetOfLastSplitPoint) { + if (recordStart == offsetOfLastSplitPoint) { throw new IllegalStateException( String.format( "Record at a split point has same offset as the previous split point: " @@ -86,12 +112,13 @@ public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, long recor offsetOfLastSplitPoint, recordStart)); } if (recordStart >= stopOffset) { + done = true; return false; } offsetOfLastSplitPoint = recordStart; + ++splitPointsSeen; } - lastRecordStart = recordStart; return true; } @@ -105,7 +132,7 @@ public synchronized boolean trySplitAtPosition(long splitOffset) { LOG.debug("Refusing to split {} at {}: stop position unspecified", this, splitOffset); return false; } - if (lastRecordStart == -1) { + if (!isStarted()) { LOG.debug("Refusing to split {} at {}: unstarted", this, splitOffset); return false; } @@ -143,17 +170,72 @@ public synchronized long getPositionForFractionConsumed(double fraction) { @Override public synchronized double getFractionConsumed() { - if (stopOffset == OFFSET_INFINITY) { + if (!isStarted()) { return 0.0; - } - if (lastRecordStart == -1) { + } else if (isDone()) { + return 1.0; + } else if (stopOffset == OFFSET_INFINITY) { return 0.0; + } else if (lastRecordStart >= stopOffset) { + return 1.0; + } else { + // E.g., when reading [3, 6) and lastRecordStart is 4, that means we consumed 3,4 of 3,4,5 + // which is (4 - 3 + 1) / (6 - 3) = 67%. + // Also, clamp to at most 1.0 because the last consumed position can extend past the + // stop position. + return Math.min(1.0, 1.0 * (lastRecordStart - startOffset + 1) / (stopOffset - startOffset)); } - // E.g., when reading [3, 6) and lastRecordStart is 4, that means we consumed 3,4 of 3,4,5 - // which is (4 - 3 + 1) / (6 - 3) = 67%. - // Also, clamp to at most 1.0 because the last consumed position can extend past the - // stop position. - return Math.min(1.0, 1.0 * (lastRecordStart - startOffset + 1) / (stopOffset - startOffset)); + } + + /** + * Returns the total number of split points that have been processed. + * + *

    A split point at a particular offset has been seen if there has been a corresponding call + * to {@link #tryReturnRecordAt(boolean, long)} with {@code isAtSplitPoint} true. It has been + * processed if there has been a subsequent call to + * {@link #tryReturnRecordAt(boolean, long)} with {@code isAtSplitPoint} true and at a larger + * offset. + * + *

    Note that for correctness when implementing {@link BoundedReader#getSplitPointsConsumed()}, + * if a reader finishes before {@link #tryReturnRecordAt(boolean, long)} returns false, + * the reader should add an additional call to {@link #markDone()}. This will indicate that + * processing for the last seen split point has been finished. + * + * @see org.apache.beam.sdk.io.OffsetBasedSource for a {@link BoundedReader} + * implemented using {@link OffsetRangeTracker}. + */ + public synchronized long getSplitPointsProcessed() { + if (!isStarted()) { + return 0; + } else if (isDone()) { + return splitPointsSeen; + } else { + // There is a current split point, and it has not finished processing. + checkState( + splitPointsSeen > 0, + "A started rangeTracker should have seen > 0 split points (is %s)", + splitPointsSeen); + return splitPointsSeen - 1; + } + } + + + /** + * Marks this range tracker as being done. Specifically, this will mark the current split point, + * if one exists, as being finished. + * + *

    Always returns false, so that it can be used in an implementation of + * {@link BoundedReader#start()} or {@link BoundedReader#advance()} as follows: + * + *

     {@code
    +   * public boolean start() {
    +   *   return startImpl() && rangeTracker.tryReturnRecordAt(isAtSplitPoint, position)
    +   *       || rangeTracker.markDone();
    +   * }} 
    + */ + public synchronized boolean markDone() { + done = true; + return false; } @Override @@ -177,7 +259,10 @@ public synchronized String toString() { @VisibleForTesting OffsetRangeTracker copy() { OffsetRangeTracker res = new OffsetRangeTracker(startOffset, stopOffset); + res.offsetOfLastSplitPoint = this.offsetOfLastSplitPoint; res.lastRecordStart = this.lastRecordStart; + res.done = this.done; + res.splitPointsSeen = this.splitPointsSeen; return res; } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java index 20c21bca11d0..13f8e7f596b0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.io; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; - import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -28,6 +28,8 @@ import org.apache.beam.sdk.coders.DefaultCoder; import org.apache.beam.sdk.io.AvroSource.AvroReader; import org.apache.beam.sdk.io.AvroSource.AvroReader.Seeker; +import org.apache.beam.sdk.io.BlockBasedSource.BlockBasedReader; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.SourceTestUtils; @@ -44,6 +46,7 @@ import org.apache.avro.reflect.AvroDefault; import org.apache.avro.reflect.Nullable; import org.apache.avro.reflect.ReflectData; +import org.hamcrest.Matchers; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -57,6 +60,7 @@ import java.io.IOException; import java.io.PushbackInputStream; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.NoSuchElementException; import java.util.Objects; @@ -197,6 +201,86 @@ public void testGetProgressFromUnstartedReader() throws Exception { } } + @Test + public void testProgress() throws Exception { + // 5 records, 2 per block. + List records = createFixedRecords(5); + String filename = generateTestFile("tmp.avro", records, SyncBehavior.SYNC_REGULAR, 2, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedSource.BoundedReader readerOrig = source.createReader(null)) { + assertThat(readerOrig, Matchers.instanceOf(BlockBasedReader.class)); + BlockBasedReader reader = (BlockBasedReader) readerOrig; + + // Before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // First 2 records are in the same block. 
+ assertTrue(reader.start()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + // continued + assertTrue(reader.advance()); + assertFalse(reader.isAtSplitPoint()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Second block -> parallelism consumed becomes 1. + assertTrue(reader.advance()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + // continued + assertTrue(reader.advance()); + assertFalse(reader.isAtSplitPoint()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Third and final block -> parallelism consumed becomes 2, remaining becomes 1. + assertTrue(reader.advance()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Done + assertFalse(reader.advance()); + assertEquals(3, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + } + } + + @Test + public void testProgressEmptySource() throws Exception { + // 0 records, 20 per block. + List records = Collections.emptyList(); + String filename = generateTestFile("tmp.avro", records, SyncBehavior.SYNC_REGULAR, 2, + AvroCoder.of(FixedRecord.class), DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedSource.BoundedReader readerOrig = source.createReader(null)) { + assertThat(readerOrig, Matchers.instanceOf(BlockBasedReader.class)); + BlockBasedReader reader = (BlockBasedReader) readerOrig; + + // before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // confirm empty + assertFalse(reader.start()); + + // after reading empty source + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + } + } + @Test public void testGetCurrentFromUnstartedReader() throws Exception { List records = createFixedRecords(DEFAULT_RECORD_COUNT); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java index 542e7341d36e..7161c1d56867 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java @@ -19,9 +19,10 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFrom; - import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -29,8 +30,11 @@ import org.apache.beam.sdk.Pipeline; import 
org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.CompressedSource.CompressedReader; import org.apache.beam.sdk.io.CompressedSource.CompressionMode; import org.apache.beam.sdk.io.CompressedSource.DecompressingChannelFactory; +import org.apache.beam.sdk.io.FileBasedSource.FileBasedReader; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; @@ -462,11 +466,12 @@ public Coder getDefaultOutputCoder() { private static class ByteReader extends FileBasedReader { ByteBuffer buff = ByteBuffer.allocate(1); Byte current; - long offset = -1; + long offset; ReadableByteChannel channel; public ByteReader(ByteSource source) { super(source); + offset = source.getStartOffset() - 1; } @Override @@ -501,4 +506,102 @@ protected long getCurrentOffset() { } } } + + @Test + public void testEmptyGzipProgress() throws IOException { + File tmpFile = tmpFolder.newFile("empty.gz"); + String filename = tmpFile.toPath().toString(); + writeFile(tmpFile, new byte[0], CompressionMode.GZIP); + + PipelineOptions options = PipelineOptionsFactory.create(); + CompressedSource source = CompressedSource.from(new ByteSource(filename, 1)); + try (BoundedReader readerOrig = source.createReader(options)) { + assertThat(readerOrig, instanceOf(CompressedReader.class)); + CompressedReader reader = (CompressedReader) readerOrig; + // before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // confirm empty + assertFalse(reader.start()); + + // after reading empty source + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + + @Test + public void testGzipProgress() throws IOException { + int numRecords = 3; + File tmpFile = tmpFolder.newFile("nonempty.gz"); + String filename = tmpFile.toPath().toString(); + writeFile(tmpFile, new byte[numRecords], CompressionMode.GZIP); + + PipelineOptions options = PipelineOptionsFactory.create(); + CompressedSource source = CompressedSource.from(new ByteSource(filename, 1)); + try (BoundedReader readerOrig = source.createReader(options)) { + assertThat(readerOrig, instanceOf(CompressedReader.class)); + CompressedReader reader = (CompressedReader) readerOrig; + // before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // confirm has three records + for (int i = 0; i < numRecords; ++i) { + if (i == 0) { + assertTrue(reader.start()); + } else { + assertTrue(reader.advance()); + } + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + } + assertFalse(reader.advance()); + + // after reading empty source + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + + @Test + public void testSplittableProgress() throws IOException { + File tmpFile = tmpFolder.newFile("nonempty.txt"); + String filename = tmpFile.toPath().toString(); + Files.write(new byte[2], tmpFile); + + PipelineOptions options = PipelineOptionsFactory.create(); + CompressedSource source = 
CompressedSource.from(new ByteSource(filename, 1)); + try (BoundedReader readerOrig = source.createReader(options)) { + assertThat(readerOrig, not(instanceOf(CompressedReader.class))); + assertThat(readerOrig, instanceOf(FileBasedReader.class)); + FileBasedReader reader = (FileBasedReader) readerOrig; + + // Check preconditions before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // First record: none consumed, unknown remaining. + assertTrue(reader.start()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Second record: 1 consumed, know that we're on the last record. + assertTrue(reader.advance()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Confirm empty and check post-conditions + assertFalse(reader.advance()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CountingSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CountingSourceTest.java index a261fb274459..bf68d41be691 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CountingSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CountingSourceTest.java @@ -24,9 +24,11 @@ import static org.junit.Assert.assertTrue; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; import org.apache.beam.sdk.io.CountingSource.CounterMark; import org.apache.beam.sdk.io.CountingSource.UnboundedCountingSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; @@ -49,6 +51,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.IOException; import java.util.List; /** @@ -115,6 +118,33 @@ public void testBoundedSourceSplits() throws Exception { p.run(); } + @Test + public void testProgress() throws IOException { + final int numRecords = 5; + @SuppressWarnings("deprecation") // testing CountingSource + BoundedSource source = CountingSource.upTo(numRecords); + try (BoundedReader reader = source.createReader(PipelineOptionsFactory.create())) { + // Check preconditions before starting. Note that CountingReader can always give an accurate + // remaining parallelism. 
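+      // Since every element of CountingSource is a split point, the assertions below expect
+      // splitPointsConsumed + splitPointsRemaining == numRecords at every step of the read
+      // (e.g. after two records: 2 consumed, numRecords - 2 remaining).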
+ assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(numRecords, reader.getSplitPointsRemaining()); + + assertTrue(reader.start()); + int i = 0; + do { + assertEquals(i, reader.getSplitPointsConsumed()); + assertEquals(numRecords - i, reader.getSplitPointsRemaining()); + ++i; + } while (reader.advance()); + + assertEquals(numRecords, i); // exactly numRecords calls to advance() + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(numRecords, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + @Test @Category(RunnableOnService.class) public void testUnboundedSource() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java index bedbc9977844..1f16d39a290b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java @@ -446,7 +446,7 @@ public void testFractionConsumedWhenReadingFilepattern() throws IOException { assertTrue(fractionConsumed > lastFractionConsumed); lastFractionConsumed = fractionConsumed; } - assertTrue(reader.getFractionConsumed() < 1.0); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java index e9b61aaa1a9c..66abd334a5cd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java @@ -19,7 +19,6 @@ import static org.apache.beam.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; import static org.apache.beam.sdk.testing.SourceTestUtils.readFromSource; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -28,6 +27,8 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -86,13 +87,12 @@ public long getMaxEndOffset(PipelineOptions options) { } @Override - public BoundedReader createReader(PipelineOptions options) throws IOException { + public OffsetBasedReader createReader(PipelineOptions options) throws IOException { return new CoarseRangeReader(this); } } - private static class CoarseRangeReader - extends OffsetBasedSource.OffsetBasedReader { + private static class CoarseRangeReader extends OffsetBasedReader { private long current = -1; private long granularity; @@ -238,6 +238,69 @@ public void testReadingGranularityAndFractionConsumed() throws IOException { } } + @Test + public void testProgress() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + CoarseRangeSource source = new CoarseRangeSource(13, 17, 1, 2); + try (OffsetBasedReader reader = source.createReader(options)) { + // Unstarted reader + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Start and produce the 
element 14 since granularity is 2. + assertTrue(reader.start()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(14, reader.getCurrent().intValue()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + // Advance and produce the element 15, not a split point. + assertTrue(reader.advance()); + assertEquals(15, reader.getCurrent().intValue()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Advance and produce the element 16, is a split point. Since the next offset (17) is + // outside the range [13, 17), remaining parallelism should become 1 from UNKNOWN. + assertTrue(reader.advance()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(16, reader.getCurrent().intValue()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); // The next offset is outside the range. + // Advance and produce the element 17, not a split point. + assertTrue(reader.advance()); + assertEquals(17, reader.getCurrent().intValue()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Advance and reach the end of the reader. + assertFalse(reader.advance()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + + @Test + public void testProgressEmptySource() throws IOException { + PipelineOptions options = PipelineOptionsFactory.create(); + CoarseRangeSource source = new CoarseRangeSource(13, 17, 1, 100); + try (OffsetBasedReader reader = source.createReader(options)) { + // before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // confirm empty + assertFalse(reader.start()); + + // after reading empty source + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + @Test public void testSplitAtFraction() throws IOException { PipelineOptions options = PipelineOptionsFactory.create(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java index 4d6d8dd4ac0b..53a2a89241df 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java @@ -22,10 +22,10 @@ import static org.apache.beam.sdk.TestUtils.NO_INTS_ARRAY; import static org.apache.beam.sdk.TestUtils.NO_LINES_ARRAY; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; - import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -34,6 +34,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.TextualIntegerCoder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; import org.apache.beam.sdk.io.TextIO.CompressionType; import 
org.apache.beam.sdk.io.TextIO.TextSource; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -422,6 +423,117 @@ public void testTextIOGetName() { "ReadMyFile [TextIO.Read]", TextIO.Read.named("ReadMyFile").from("somefile").toString()); } + @Test + public void testProgressEmptyFile() throws IOException { + try (BoundedReader reader = + prepareSource(new byte[0]).createReader(PipelineOptionsFactory.create())) { + // Check preconditions before starting. + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Assert empty + assertFalse(reader.start()); + + // Check postconditions after finishing + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + + @Test + public void testProgressTextFile() throws IOException { + String file = "line1\nline2\nline3"; + try (BoundedReader reader = + prepareSource(file.getBytes()).createReader(PipelineOptionsFactory.create())) { + // Check preconditions before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Line 1 + assertTrue(reader.start()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Line 2 + assertTrue(reader.advance()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Line 3 + assertTrue(reader.advance()); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Check postconditions after finishing + assertFalse(reader.advance()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(3, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + + @Test + public void testProgressAfterSplitting() throws IOException { + String file = "line1\nline2\nline3"; + BoundedSource source = prepareSource(file.getBytes()); + BoundedSource remainder; + + // Create the remainder, verifying properties pre- and post-splitting. + try (BoundedReader readerOrig = source.createReader(PipelineOptionsFactory.create())) { + // Preconditions. + assertEquals(0.0, readerOrig.getFractionConsumed(), 1e-6); + assertEquals(0, readerOrig.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, readerOrig.getSplitPointsRemaining()); + + // First record, before splitting. + assertTrue(readerOrig.start()); + assertEquals(0, readerOrig.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, readerOrig.getSplitPointsRemaining()); + + // Split. 0.1 is in line1, so should now be able to detect last record. + remainder = readerOrig.splitAtFraction(0.1); + System.err.println(readerOrig.getCurrentSource()); + assertNotNull(remainder); + + // First record, after splitting. + assertEquals(0, readerOrig.getSplitPointsConsumed()); + assertEquals(1, readerOrig.getSplitPointsRemaining()); + + // Finish and postconditions. 
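+      // After splitting at 10% of the file, the primary source retains only line1, which has
+      // already been read, so the next advance() returns false and its single split point is
+      // counted as consumed.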
+ assertFalse(readerOrig.advance()); + assertEquals(1.0, readerOrig.getFractionConsumed(), 1e-6); + assertEquals(1, readerOrig.getSplitPointsConsumed()); + assertEquals(0, readerOrig.getSplitPointsRemaining()); + } + + // Check the properties of the remainder. + try (BoundedReader reader = remainder.createReader(PipelineOptionsFactory.create())) { + // Preconditions. + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // First record should be line 2. + assertTrue(reader.start()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Second record is line 3 + assertTrue(reader.advance()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Check postconditions after finishing + assertFalse(reader.advance()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + } + } + @Test public void testReadEmptyLines() throws Exception { runTestReadWithData("\n\n\n".getBytes(StandardCharsets.UTF_8), diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/range/OffsetRangeTrackerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/range/OffsetRangeTrackerTest.java index 3de04f762208..edd4c4f1cb6b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/range/OffsetRangeTrackerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/range/OffsetRangeTrackerTest.java @@ -104,7 +104,6 @@ public void testSplitAtOffset() throws Exception { assertFalse(tracker.tryReturnRecordAt(true, 150)); assertFalse(tracker.tryReturnRecordAt(true, 151)); // Should accept non-splitpoint records starting after stop offset. - assertTrue(tracker.tryReturnRecordAt(false, 135)); assertTrue(tracker.tryReturnRecordAt(false, 152)); assertTrue(tracker.tryReturnRecordAt(false, 160)); assertFalse(tracker.tryReturnRecordAt(true, 171)); diff --git a/sdks/java/io/hdfs/src/main/java/org/apache/beam/sdk/io/hdfs/HDFSFileSource.java b/sdks/java/io/hdfs/src/main/java/org/apache/beam/sdk/io/hdfs/HDFSFileSource.java index ab537eb6ca46..41a271cbd8cc 100644 --- a/sdks/java/io/hdfs/src/main/java/org/apache/beam/sdk/io/hdfs/HDFSFileSource.java +++ b/sdks/java/io/hdfs/src/main/java/org/apache/beam/sdk/io/hdfs/HDFSFileSource.java @@ -282,6 +282,7 @@ static class HDFSFileReader extends BoundedSource.BoundedReader> private Configuration conf; private RecordReader currentReader; private KV currentPair; + private volatile boolean done = false; /** * Create a {@code HDFSFileReader} based on a file or a file pattern specification. @@ -356,6 +357,7 @@ public boolean advance() throws IOException { } // either no next split or all readers were empty currentPair = null; + done = true; return false; } } catch (InterruptedException e) { @@ -432,6 +434,16 @@ private Double getProgress() { } } + @Override + public final long getSplitPointsRemaining() { + if (done) { + return 0; + } + // This source does not currently support dynamic work rebalancing, so remaining + // parallelism is always 1. + return 1; + } + @Override public BoundedSource> splitAtFraction(double fraction) { // Not yet supported. 
To implement this, the sizes of the splits should be used to From 1e669c44c9d2448b55f5bdba3dcff1831b2cd8b4 Mon Sep 17 00:00:00 2001 From: Scott Wegner Date: Thu, 19 May 2016 09:17:37 -0700 Subject: [PATCH 21/21] Fix bug in PipelineOptions DisplayData serialization PipelineOptions has been improved to generate display data to be consumed by a runner and used for display. However, there was a bug in the ProxyInvocationHandler implementation of PipelineOptions display data which was causing NullPointerExceptions when generated display data from PipelineOptions previously deserialized from JSON. This change also makes our error handling for display data exceptions consistent across the Dataflow runner: exceptions thrown during display data population will propogate out and cause the pipeline to fail. This is consistent with other user code which may throw exceptions at pipeline construction time. --- .../dataflow/DataflowPipelineTranslator.java | 50 +--------- .../DataflowPipelineTranslatorTest.java | 63 ------------- .../sdk/options/ProxyInvocationHandler.java | 4 +- .../sdk/transforms/display/DisplayData.java | 14 ++- .../options/ProxyInvocationHandlerTest.java | 6 +- .../transforms/display/DisplayDataTest.java | 91 +++++++++++++------ 6 files changed, 83 insertions(+), 145 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 7f673932f33d..f5fefc0f6f94 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -29,7 +29,6 @@ import static org.apache.beam.sdk.util.Structs.getString; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; import org.apache.beam.runners.dataflow.DataflowPipelineRunner.GroupByKeyAndSortValuesOnly; import org.apache.beam.runners.dataflow.internal.ReadTranslator; @@ -87,8 +86,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.io.PrintWriter; -import java.io.StringWriter; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -726,18 +723,7 @@ private void addOutput(String name, PValue value, Coder valueCoder) { } private void addDisplayData(String stepName, HasDisplayData hasDisplayData) { - DisplayData displayData; - try { - displayData = DisplayData.from(hasDisplayData); - } catch (Exception e) { - String msg = String.format("Exception thrown while collecting display data for step: %s. " - + "Display data will be not be available for this step.", stepName); - DisplayDataException displayDataException = new DisplayDataException(msg, e); - LOG.warn(msg, displayDataException); - - displayData = displayDataException.asDisplayData(); - } - + DisplayData displayData = DisplayData.from(hasDisplayData); List> list = MAPPER.convertValue(displayData, List.class); addList(getProperties(), PropertyNames.DISPLAY_DATA, list); } @@ -1056,38 +1042,4 @@ private static void translateOutputs( context.addOutput(tag.getId(), output); } } - - /** - * Wraps exceptions thrown while collecting {@link DisplayData} for the Dataflow pipeline runner. 
- */ - static class DisplayDataException extends Exception implements HasDisplayData { - public DisplayDataException(String message, Throwable cause) { - super(checkNotNull(message), checkNotNull(cause)); - } - - /** - * Retrieve a display data representation of the exception, which can be submitted to - * the service in place of the actual display data. - */ - public DisplayData asDisplayData() { - return DisplayData.from(this); - } - - @Override - public void populateDisplayData(DisplayData.Builder builder) { - Throwable cause = getCause(); - builder - .add(DisplayData.item("exceptionMessage", getMessage())) - .add(DisplayData.item("exceptionType", cause.getClass())) - .add(DisplayData.item("exceptionCause", cause.getMessage())) - .add(DisplayData.item("stackTrace", stackTraceToString())); - } - - private String stackTraceToString() { - StringWriter stringWriter = new StringWriter(); - PrintWriter printWriter = new PrintWriter(stringWriter); - printStackTrace(printWriter); - return stringWriter.toString(); - } - } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 58c6f75cbdf6..165d2b51ba95 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -23,9 +23,7 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasKey; -import static org.hamcrest.Matchers.is; import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -50,7 +48,6 @@ import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.RecordingPipelineVisitor; -import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; @@ -80,7 +77,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import org.hamcrest.Matchers; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -106,9 +102,7 @@ */ @RunWith(JUnit4.class) public class DataflowPipelineTranslatorTest implements Serializable { - @Rule public transient ExpectedException thrown = ExpectedException.none(); - @Rule public transient ExpectedLogs logs = ExpectedLogs.none(DataflowPipelineTranslator.class); // A Custom Mockito matcher for an initial Job that checks that all // expected fields are set. 
@@ -973,61 +967,4 @@ public void populateDisplayData(DisplayData.Builder builder) { assertEquals(expectedFn1DisplayData, ImmutableSet.copyOf(fn1displayData)); assertEquals(expectedFn2DisplayData, ImmutableSet.copyOf(fn2displayData)); } - - @Test - public void testCapturesDisplayDataExceptions() throws IOException { - DataflowPipelineOptions options = buildPipelineOptions(); - DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); - Pipeline pipeline = Pipeline.create(options); - - final RuntimeException displayDataException = new RuntimeException("foobar"); - pipeline - .apply(Create.of(1, 2, 3)) - .apply(ParDo.of(new DoFn() { - @Override - public void processElement(ProcessContext c) throws Exception { - c.output(c.element()); - } - - @Override - public void populateDisplayData(DisplayData.Builder builder) { - throw displayDataException; - } - })); - - Job job = translator.translate( - pipeline, - (DataflowPipelineRunner) pipeline.getRunner(), - Collections.emptyList()).getJob(); - - String expectedMessage = "Display data will be not be available for this step"; - logs.verifyWarn(expectedMessage); - - List steps = job.getSteps(); - assertEquals("Job should have 2 steps", 2, steps.size()); - - @SuppressWarnings("unchecked") - Iterable> displayData = (Collection>) steps.get(1) - .getProperties().get("display_data"); - - String namespace = DataflowPipelineTranslator.DisplayDataException.class.getName(); - Assert.assertThat(displayData, Matchers.>hasItem(allOf( - hasEntry("namespace", namespace), - hasEntry("key", "exceptionType"), - hasEntry("value", RuntimeException.class.getName())))); - - Assert.assertThat(displayData, Matchers.>hasItem(allOf( - hasEntry("namespace", namespace), - hasEntry("key", "exceptionMessage"), - hasEntry(is("value"), Matchers.containsString(expectedMessage))))); - - Assert.assertThat(displayData, Matchers.>hasItem(allOf( - hasEntry("namespace", namespace), - hasEntry("key", "exceptionCause"), - hasEntry("value", "foobar")))); - - Assert.assertThat(displayData, Matchers.>hasItem(allOf( - hasEntry("namespace", namespace), - hasEntry("key", "stackTrace")))); - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java index 159eb5bb6c89..3292a7f366b5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java @@ -315,6 +315,7 @@ private void populateDisplayData(DisplayData.Builder builder) { } Object value = getValueFromJson(jsonOption.getKey(), spec.getGetterMethod()); + value = value == null ? 
"" : value; DisplayData.Type type = DisplayData.inferType(value); if (type != null) { builder.add(DisplayData.item(jsonOption.getKey(), type, value) @@ -552,7 +553,8 @@ public void serialize(PipelineOptions value, JsonGenerator jgen, SerializerProvi jgen.writeObject(serializableOptions); List> serializedDisplayData = Lists.newArrayList(); - for (DisplayData.Item item : DisplayData.from(value).items()) { + DisplayData displayData = DisplayData.from(value); + for (DisplayData.Item item : displayData.items()) { @SuppressWarnings("unchecked") Map serializedItem = MAPPER.convertValue(item, Map.class); serializedDisplayData.add(serializedItem); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/display/DisplayData.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/display/DisplayData.java index dc6e381d0929..9e9bdbfab2b7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/display/DisplayData.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/display/DisplayData.java @@ -72,10 +72,6 @@ public static DisplayData none() { * Collect the {@link DisplayData} from a component. This will traverse all subcomponents * specified via {@link Builder#include} in the given component. Data in this component will be in * a namespace derived from the component. - * - *

    Pipeline runners should call this method in order to collect display data. While it should - * be safe to call {@code DisplayData.from} on any component which implements it, runners should - * be resilient to exceptions thrown while collecting display data. */ public static DisplayData from(HasDisplayData component) { checkNotNull(component, "component argument cannot be null"); @@ -603,7 +599,15 @@ public Builder include(HasDisplayData subComponent, String namespace) { if (newComponent) { String prevNs = this.latestNs; this.latestNs = namespace; - subComponent.populateDisplayData(this); + + try { + subComponent.populateDisplayData(this); + } catch (Throwable e) { + String msg = String.format("Error while populating display data for component: %s", + namespace); + throw new RuntimeException(msg, e); + } + this.latestNs = prevNs; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java index 6fc970015d3c..110f30a2247a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java @@ -863,12 +863,16 @@ public void testDisplayDataIncludesExplicitlySetDefaults() { } @Test - public void testDisplayDataNullValuesConvertedToEmptyString() { + public void testDisplayDataNullValuesConvertedToEmptyString() throws Exception { FooOptions options = PipelineOptionsFactory.as(FooOptions.class); options.setFoo(null); DisplayData data = DisplayData.from(options); assertThat(data, hasDisplayItem("foo", "")); + + FooOptions deserializedOptions = serializeDeserialize(FooOptions.class, options); + DisplayData deserializedData = DisplayData.from(deserializedOptions); + assertThat(deserializedData, hasDisplayItem("foo", "")); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java index 21b2e3388a5e..478724ba6fd5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/display/DisplayDataTest.java @@ -30,6 +30,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.isEmptyOrNullString; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; @@ -39,11 +40,6 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.RunnableOnService; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -69,7 +65,6 @@ import org.joda.time.format.ISODateTimeFormat; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -86,6 +81,7 @@ @RunWith(JUnit4.class) public class DisplayDataTest implements Serializable { @Rule public transient 
ExpectedException thrown = ExpectedException.none(); + private static final DateTimeFormatter ISO_FORMATTER = ISODateTimeFormat.dateTime(); private static final ObjectMapper MAPPER = new ObjectMapper(); @@ -413,7 +409,7 @@ public void populateDisplayData(Builder builder) { @Test public void testNullNamespaceOverride() { - thrown.expect(NullPointerException.class); + thrown.expectCause(isA(NullPointerException.class)); DisplayData.from(new HasDisplayData() { @Override @@ -516,7 +512,7 @@ public void populateDisplayData(DisplayData.Builder builder) { @Test public void testDuplicateKeyThrowsException() { - thrown.expect(IllegalArgumentException.class); + thrown.expectCause(isA(IllegalArgumentException.class)); DisplayData.from( new HasDisplayData() { @Override @@ -752,7 +748,7 @@ public void populateDisplayData(Builder builder) { } }; - thrown.expect(ClassCastException.class); + thrown.expectCause(isA(ClassCastException.class)); DisplayData.from(component); } @@ -838,7 +834,7 @@ public void testFromNull() { @Test public void testIncludeNull() { - thrown.expect(NullPointerException.class); + thrown.expectCause(isA(NullPointerException.class)); DisplayData.from( new HasDisplayData() { @Override @@ -856,7 +852,7 @@ public void populateDisplayData(Builder builder) { } }; - thrown.expect(NullPointerException.class); + thrown.expectCause(isA(NullPointerException.class)); DisplayData.from(new HasDisplayData() { @Override public void populateDisplayData(Builder builder) { @@ -867,7 +863,7 @@ public void populateDisplayData(Builder builder) { @Test public void testNullKey() { - thrown.expect(NullPointerException.class); + thrown.expectCause(isA(NullPointerException.class)); DisplayData.from( new HasDisplayData() { @Override @@ -968,23 +964,66 @@ public void populateDisplayData(Builder builder) { } /** - * Validate that all runners are resilient to exceptions thrown while retrieving display data. + * Verify that {@link DisplayData.Builder} can recover from exceptions thrown in user code. + * This is not used within the Beam SDK since we want all code to produce valid DisplayData. + * This test just ensures it is possible to write custom code that does recover. 
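+   * The non-recovering path, where the exception propagates out of {@code DisplayData.from},
+   * is exercised by {@link #testExceptionMessage()}.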
*/ @Test - @Category(RunnableOnService.class) - public void testRunnersResilientToDisplayDataExceptions() { - Pipeline p = TestPipeline.create(); - PCollection pCol = p - .apply(Create.of(1, 2, 3)) - .apply(new IdentityTransform() { - @Override - public void populateDisplayData(Builder builder) { - throw new RuntimeException("bug!"); - } - }); + public void testCanRecoverFromBuildException() { + final HasDisplayData safeComponent = new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder.add(DisplayData.item("a", "a")); + } + }; + + final HasDisplayData failingComponent = new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + throw new RuntimeException("oh noes!"); + } + }; + + DisplayData displayData = DisplayData.from(new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder + .add(DisplayData.item("b", "b")) + .add(DisplayData.item("c", "c")); + + try { + builder.include(failingComponent); + fail("Expected exception not thrown"); + } catch (RuntimeException e) { + // Expected + } + + builder + .include(safeComponent) + .add(DisplayData.item("d", "d")); + } + }); + + assertThat(displayData, hasDisplayItem("a")); + assertThat(displayData, hasDisplayItem("b")); + assertThat(displayData, hasDisplayItem("c")); + assertThat(displayData, hasDisplayItem("d")); + } - PAssert.that(pCol).containsInAnyOrder(1, 2, 3); - p.run(); + @Test + public void testExceptionMessage() { + final RuntimeException cause = new RuntimeException("oh noes!"); + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + throw cause; + } + }; + + thrown.expectMessage(component.getClass().getName()); + thrown.expectCause(is(cause)); + + DisplayData.from(component); } private static class IdentityTransform extends PTransform, PCollection> {