From f2e3088633fef10f19bfd11ff9b508930916a740 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 7 Jun 2017 17:00:57 -0700 Subject: [PATCH 001/200] Remove support for NativeSinks from the Python DirectRunner --- .../runners/direct/transform_evaluator.py | 62 +------------------ 1 file changed, 1 insertion(+), 61 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index b1cb626ca0cb6..0fec8b8cb1db3 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -29,7 +29,6 @@ from apache_beam.runners.common import DoFnState from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.runners.direct.transform_result import TransformResult -from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue @@ -54,7 +53,6 @@ def __init__(self, evaluation_context): core.Flatten: _FlattenEvaluator, core.ParDo: _ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, - _NativeWrite: _NativeWriteEvaluator, } def for_application( @@ -98,8 +96,7 @@ def should_execute_serially(self, applied_ptransform): Returns: True if executor should execute applied_ptransform serially. """ - return isinstance(applied_ptransform.transform, - (core._GroupByKeyOnly, _NativeWrite)) + return isinstance(applied_ptransform.transform, core._GroupByKeyOnly) class _TransformEvaluator(object): @@ -403,60 +400,3 @@ def len_element_fn(element): return TransformResult( self._applied_ptransform, bundles, state, None, None, hold) - - -class _NativeWriteEvaluator(_TransformEvaluator): - """TransformEvaluator for _NativeWrite transform.""" - - def __init__(self, evaluation_context, applied_ptransform, - input_committed_bundle, side_inputs, scoped_metrics_container): - assert not side_inputs - super(_NativeWriteEvaluator, self).__init__( - evaluation_context, applied_ptransform, input_committed_bundle, - side_inputs, scoped_metrics_container) - - assert applied_ptransform.transform.sink - self._sink = applied_ptransform.transform.sink - - @property - def _is_final_bundle(self): - return (self._execution_context.watermarks.input_watermark - == WatermarkManager.WATERMARK_POS_INF) - - @property - def _has_already_produced_output(self): - return (self._execution_context.watermarks.output_watermark - == WatermarkManager.WATERMARK_POS_INF) - - def start_bundle(self): - # state: [values] - self.state = (self._execution_context.existing_state - if self._execution_context.existing_state else []) - - def process_element(self, element): - self.state.append(element) - - def finish_bundle(self): - # finish_bundle will append incoming bundles in memory until all the bundles - # carrying data is processed. This is done to produce only a single output - # shard (some tests depends on this behavior). It is possible to have - # incoming empty bundles after the output is produced, these bundles will be - # ignored and would not generate additional output files. - # TODO(altay): Do not wait until the last bundle to write in a single shard. - if self._is_final_bundle: - if self._has_already_produced_output: - # Ignore empty bundles that arrive after the output is produced. 
- assert self.state == [] - else: - self._sink.pipeline_options = self._evaluation_context.pipeline_options - with self._sink.writer() as writer: - for v in self.state: - writer.Write(v.value) - state = None - hold = WatermarkManager.WATERMARK_POS_INF - else: - state = self.state - hold = WatermarkManager.WATERMARK_NEG_INF - - return TransformResult( - self._applied_ptransform, [], state, None, None, hold) From d94ac58ea2d12f55743e8ad27a02bdb83c194da7 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Wed, 7 Jun 2017 16:26:21 -0700 Subject: [PATCH 002/200] Make BytesCoder to be a known type --- sdks/python/apache_beam/coders/coders.py | 5 +++++ sdks/python/apache_beam/runners/worker/operation_specs.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index f40045d142ffd..f3e0b432e51ca 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -286,6 +286,11 @@ def _create_impl(self): def is_deterministic(self): return True + def as_cloud_object(self): + return { + '@type': 'kind:bytes', + } + def __eq__(self, other): return type(self) == type(other) diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py index db5eb765598b6..b8d19a1427563 100644 --- a/sdks/python/apache_beam/runners/worker/operation_specs.py +++ b/sdks/python/apache_beam/runners/worker/operation_specs.py @@ -339,6 +339,10 @@ def get_coder_from_spec(coder_spec): assert len(coder_spec['component_encodings']) == 1 return coders.coders.LengthPrefixCoder( get_coder_from_spec(coder_spec['component_encodings'][0])) + elif coder_spec['@type'] == 'kind:bytes': + assert ('component_encodings' not in coder_spec + or len(coder_spec['component_encodings'] == 0)) + return coders.BytesCoder() # We pass coders in the form "$" to make the job # description JSON more readable. From b5852d212cab060321c43a5800f8585aa3649aec Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Wed, 7 Jun 2017 16:28:18 -0700 Subject: [PATCH 003/200] Add coder info to pubsub io --- sdks/python/apache_beam/io/gcp/pubsub.py | 32 ++++++++++++++----- sdks/python/apache_beam/io/gcp/pubsub_test.py | 28 ++++++++++++++-- .../runners/dataflow/dataflow_runner.py | 23 +++++++++---- 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 1ba8ac051272c..40326e10295aa 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -40,13 +40,15 @@ class ReadStringsFromPubSub(PTransform): """A ``PTransform`` for reading utf-8 string payloads from Cloud Pub/Sub.""" - def __init__(self, topic, subscription=None, id_label=None): + def __init__(self, topic=None, subscription=None, id_label=None): """Initializes ``ReadStringsFromPubSub``. Attributes: - topic: Cloud Pub/Sub topic in the form "/topics//". - subscription: Optional existing Cloud Pub/Sub subscription to use in the - form "projects//subscriptions/". + topic: Cloud Pub/Sub topic in the form "/topics//". If + provided then subscription must be None. + subscription: Existing Cloud Pub/Sub subscription to use in the + form "projects//subscriptions/". If provided then + topic must be None. id_label: The attribute on incoming Pub/Sub messages to use as a unique record identifier. 
When specified, the value of this attribute (which can be any string that uniquely identifies the record) will be used for @@ -55,6 +57,12 @@ def __init__(self, topic, subscription=None, id_label=None): case, deduplication of the stream will be strictly best effort. """ super(ReadStringsFromPubSub, self).__init__() + if topic and subscription: + raise ValueError("Only one of topic or subscription should be provided.") + + if not (topic or subscription): + raise ValueError("Either a topic or subscription must be provided.") + self._source = _PubSubPayloadSource( topic, subscription=subscription, @@ -90,9 +98,11 @@ class _PubSubPayloadSource(dataflow_io.NativeSource): """Source for the payload of a message as bytes from a Cloud Pub/Sub topic. Attributes: - topic: Cloud Pub/Sub topic in the form "/topics//". - subscription: Optional existing Cloud Pub/Sub subscription to use in the - form "projects//subscriptions/". + topic: Cloud Pub/Sub topic in the form "/topics//". If + provided then topic must be None. + subscription: Existing Cloud Pub/Sub subscription to use in the + form "projects//subscriptions/". If provided then + subscription must be None. id_label: The attribute on incoming Pub/Sub messages to use as a unique record identifier. When specified, the value of this attribute (which can be any string that uniquely identifies the record) will be used for @@ -101,7 +111,10 @@ class _PubSubPayloadSource(dataflow_io.NativeSource): case, deduplication of the stream will be strictly best effort. """ - def __init__(self, topic, subscription=None, id_label=None): + def __init__(self, topic=None, subscription=None, id_label=None): + # we are using this coder explicitly for portability reasons of PubsubIO + # across implementations in languages. + self.coder = coders.BytesCoder() self.topic = topic self.subscription = subscription self.id_label = id_label @@ -131,6 +144,9 @@ class _PubSubPayloadSink(dataflow_io.NativeSink): """Sink for the payload of a message as bytes to a Cloud Pub/Sub topic.""" def __init__(self, topic): + # we are using this coder explicitly for portability reasons of PubsubIO + # across implementations in languages. 
+ self.coder = coders.BytesCoder() self.topic = topic @property diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 322d08a34cb3b..cf14e8c1d9217 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -34,9 +34,9 @@ class TestReadStringsFromPubSub(unittest.TestCase): - def test_expand(self): + def test_expand_with_topic(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label') + pcoll = p | ReadStringsFromPubSub('a_topic', None, 'a_label') # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) @@ -47,9 +47,33 @@ def test_expand(self): # Ensure that the properties passed through correctly source = read_pcoll.producer.transform.source self.assertEqual('a_topic', source.topic) + self.assertEqual('a_label', source.id_label) + + def test_expand_with_subscription(self): + p = TestPipeline() + pcoll = p | ReadStringsFromPubSub(None, 'a_subscription', 'a_label') + # Ensure that the output type is str + self.assertEqual(unicode, pcoll.element_type) + + # Ensure that the type on the intermediate read output PCollection is bytes + read_pcoll = pcoll.producer.inputs[0] + self.assertEqual(bytes, read_pcoll.element_type) + + # Ensure that the properties passed through correctly + source = read_pcoll.producer.transform.source self.assertEqual('a_subscription', source.subscription) self.assertEqual('a_label', source.id_label) + def test_expand_with_both_topic_and_subscription(self): + with self.assertRaisesRegexp( + ValueError, "Only one of topic or subscription should be provided."): + ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label') + + def test_expand_with_no_topic_or_subscription(self): + with self.assertRaisesRegexp( + ValueError, "Either a topic or subscription must be provided."): + ReadStringsFromPubSub(None, None, 'a_label') + class TestWriteStringsToPubSub(unittest.TestCase): def test_expand(self): diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 3fc8983150459..d9aa1bf098d7a 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -618,10 +618,12 @@ def run_Read(self, transform_node): if not standard_options.streaming: raise ValueError('PubSubPayloadSource is currently available for use ' 'only in streaming pipelines.') - step.add_property(PropertyNames.PUBSUB_TOPIC, transform.source.topic) - if transform.source.subscription: + # Only one of topic or subscription should be set. + if transform.source.topic: + step.add_property(PropertyNames.PUBSUB_TOPIC, transform.source.topic) + elif transform.source.subscription: step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, - transform.source.topic) + transform.source.subscription) if transform.source.id_label: step.add_property(PropertyNames.PUBSUB_ID_LABEL, transform.source.id_label) @@ -639,7 +641,12 @@ def run_Read(self, transform_node): # step should be the type of value outputted by each step. Read steps # automatically wrap output values in a WindowedValue wrapper, if necessary. # This is also necessary for proper encoding for size estimation. - coder = coders.WindowedValueCoder(transform._infer_output_coder()) # pylint: disable=protected-access + # Using a GlobalWindowCoder as a place holder instead of the default + # PickleCoder because GlobalWindowCoder is known coder. 
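A minimal sketch of how the bytes coder introduced in this series round-trips through the cloud-object representation (the import paths follow the files touched above; the assertion itself is illustrative and not part of the patch):

    from apache_beam import coders
    from apache_beam.runners.worker.operation_specs import get_coder_from_spec

    # BytesCoder now advertises the well-known 'kind:bytes' type...
    spec = coders.BytesCoder().as_cloud_object()   # {'@type': 'kind:bytes'}

    # ...and the worker resolves that spec back to a BytesCoder instead of
    # falling back to a pickled coder.
    assert isinstance(get_coder_from_spec(spec), coders.BytesCoder)

The GlobalWindowCoder placeholder chosen in the surrounding dataflow_runner.py changes follows the same reasoning: it is a coder the service already knows, so the step encoding stays portable.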
+ # TODO(robertwb): Query the collection for the windowfn to extract the + # correct coder. + coder = coders.WindowedValueCoder(transform._infer_output_coder(), + coders.coders.GlobalWindowCoder()) # pylint: disable=protected-access step.encoding = self._get_cloud_encoding(coder) step.add_property( @@ -708,8 +715,12 @@ def run__NativeWrite(self, transform_node): step.add_property(PropertyNames.FORMAT, transform.sink.format) # Wrap coder in WindowedValueCoder: this is necessary for proper encoding - # for size estimation. - coder = coders.WindowedValueCoder(transform.sink.coder) + # for size estimation. Using a GlobalWindowCoder as a place holder instead + # of the default PickleCoder because GlobalWindowCoder is known coder. + # TODO(robertwb): Query the collection for the windowfn to extract the + # correct coder. + coder = coders.WindowedValueCoder(transform.sink.coder, + coders.coders.GlobalWindowCoder()) step.encoding = self._get_cloud_encoding(coder) step.add_property(PropertyNames.ENCODING, step.encoding) step.add_property( From ccf7344820d6c69ca922aa3176dc141718382629 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Thu, 1 Jun 2017 18:39:58 -0700 Subject: [PATCH 004/200] Expand all PValues to component PCollections always Update the implementation of WriteView The PCollectionView is constructed within the composite override, but WriteView just produces a primitive PCollection which has no consumers. Track the ViewWriter within the Direct Runner, and utilize that transform rather than the producer to update PCollection Watermarks. Remove most Flink View overrides. All of the overrides are materially identical within the flink runner, so use a single override to replace all of them. --- .../apache/beam/runners/apex/ApexRunner.java | 59 +-- .../translation/ApexPipelineTranslator.java | 16 +- .../construction/RunnerPCollectionView.java | 8 + .../beam/runners/direct/DirectGraph.java | 4 + .../runners/direct/DirectGraphVisitor.java | 22 +- .../runners/direct/ViewEvaluatorFactory.java | 8 +- .../runners/direct/ViewOverrideFactory.java | 29 +- .../beam/runners/direct/DirectGraphs.java | 7 + .../runners/direct/EvaluationContextTest.java | 5 +- .../ImmutabilityEnforcementFactoryTest.java | 4 +- .../runners/direct/ParDoEvaluatorTest.java | 1 + .../runners/direct/TransformExecutorTest.java | 1 + .../direct/ViewEvaluatorFactoryTest.java | 5 +- .../direct/ViewOverrideFactoryTest.java | 16 +- .../direct/WatermarkCallbackExecutorTest.java | 1 + .../runners/direct/WatermarkManagerTest.java | 1 + .../flink/CreateStreamingFlinkView.java | 154 ++++++++ .../FlinkStreamingPipelineTranslator.java | 36 +- .../FlinkStreamingTransformTranslators.java | 8 +- .../flink/FlinkStreamingViewOverrides.java | 372 ------------------ .../runners/dataflow/BatchViewOverrides.java | 182 +++------ .../runners/dataflow/CreateDataflowView.java | 8 +- .../dataflow/DataflowPipelineTranslator.java | 11 +- .../beam/runners/dataflow/DataflowRunner.java | 17 +- .../dataflow/StreamingViewOverrides.java | 10 +- .../DataflowPipelineTranslatorTest.java | 6 +- .../translation/TransformTranslator.java | 50 +-- .../beam/sdk/runners/TransformHierarchy.java | 46 ++- .../apache/beam/sdk/transforms/Combine.java | 17 +- .../org/apache/beam/sdk/transforms/View.java | 38 +- .../apache/beam/sdk/values/PCollection.java | 12 + .../beam/sdk/values/PCollectionViews.java | 14 + .../apache/beam/sdk/values/PValueBase.java | 12 - .../sdk/testing/PCollectionViewTesting.java | 8 + 34 files changed, 458 insertions(+), 730 deletions(-) create mode 100644 
runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java delete mode 100644 runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingViewOverrides.java diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java index c595b3f50b2e9..95b354a9fe337 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java @@ -62,8 +62,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.View.AsIterable; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; @@ -214,7 +212,7 @@ public void populateDAG(DAG dag, Configuration conf) { * @param The type associated with the {@link PCollectionView} used as a side input */ public static class CreateApexPCollectionView - extends PTransform>, PCollectionView> { + extends PTransform, PCollection> { private static final long serialVersionUID = 1L; private PCollectionView view; @@ -228,7 +226,13 @@ public static CreateApexPCollectionView of( } @Override - public PCollectionView expand(PCollection> input) { + public PCollection expand(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); + } + + public PCollectionView getView() { return view; } } @@ -241,7 +245,7 @@ public void processElement(ProcessContext c) { } private static class StreamingWrapSingletonInList - extends PTransform, PCollectionView> { + extends PTransform, PCollection> { private static final long serialVersionUID = 1L; CreatePCollectionView transform; @@ -254,10 +258,11 @@ private StreamingWrapSingletonInList( } @Override - public PCollectionView expand(PCollection input) { - return input + public PCollection expand(PCollection input) { + input .apply(ParDo.of(new WrapAsList())) - .apply(CreateApexPCollectionView.of(transform.getView())); + .apply(CreateApexPCollectionView., T>of(transform.getView())); + return input; } @Override @@ -267,15 +272,12 @@ protected String getKindString() { static class Factory extends SingleInputOutputOverrideFactory< - PCollection, PCollectionView, + PCollection, PCollection, CreatePCollectionView> { @Override - public PTransformReplacement, PCollectionView> - getReplacementTransform( - AppliedPTransform< - PCollection, PCollectionView, - CreatePCollectionView> - transform) { + public PTransformReplacement, PCollection> getReplacementTransform( + AppliedPTransform, PCollection, CreatePCollectionView> + transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), new StreamingWrapSingletonInList<>(transform.getTransform())); @@ -284,18 +286,19 @@ static class Factory } private static class StreamingViewAsIterable - extends PTransform, PCollectionView>> { + extends PTransform, PCollection> { private static final long serialVersionUID = 1L; + private final PCollectionView> view; - private StreamingViewAsIterable() {} + private StreamingViewAsIterable(PCollectionView> view) { + this.view = view; + } @Override - public PCollectionView> expand(PCollection input) { - 
PCollectionView> view = - PCollectionViews.iterableView(input, input.getWindowingStrategy(), input.getCoder()); - - return input.apply(Combine.globally(new Concatenate()).withoutDefaults()) - .apply(CreateApexPCollectionView.> of(view)); + public PCollection expand(PCollection input) { + return ((PCollection) + input.apply(Combine.globally(new Concatenate()).withoutDefaults())) + .apply(CreateApexPCollectionView.>of(view)); } @Override @@ -305,15 +308,17 @@ protected String getKindString() { static class Factory extends SingleInputOutputOverrideFactory< - PCollection, PCollectionView>, View.AsIterable> { + PCollection, PCollection, CreatePCollectionView>> { @Override - public PTransformReplacement, PCollectionView>> + public PTransformReplacement, PCollection> getReplacementTransform( - AppliedPTransform, PCollectionView>, AsIterable> + AppliedPTransform< + PCollection, PCollection, + CreatePCollectionView>> transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new StreamingViewAsIterable()); + new StreamingViewAsIterable(transform.getTransform().getView())); } } } diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java index bda074b0a29b1..02f53eccdc12d 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java @@ -39,7 +39,6 @@ import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -154,7 +153,6 @@ public void translate(Read.Bounded transform, TranslationContext context) { unboundedSource, true, context.getPipelineOptions()); context.addOperator(operator, operator.output); } - } private static class CreateApexPCollectionViewTranslator @@ -162,11 +160,10 @@ private static class CreateApexPCollectionViewTranslator private static final long serialVersionUID = 1L; @Override - public void translate(CreateApexPCollectionView transform, - TranslationContext context) { - PCollectionView view = (PCollectionView) context.getOutput(); - context.addView(view); - LOG.debug("view {}", view.getName()); + public void translate( + CreateApexPCollectionView transform, TranslationContext context) { + context.addView(transform.getView()); + LOG.debug("view {}", transform.getView().getName()); } } @@ -177,9 +174,8 @@ private static class CreatePCollectionViewTranslator @Override public void translate( CreatePCollectionView transform, TranslationContext context) { - PCollectionView view = (PCollectionView) context.getOutput(); - context.addView(view); - LOG.debug("view {}", view.getName()); + context.addView(transform.getView()); + LOG.debug("view {}", transform.getView().getName()); } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java index 89e878496ef3e..c359cecce361d 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java +++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.core.construction; +import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput; @@ -26,6 +27,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.PValueBase; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; @@ -85,4 +87,10 @@ public WindowMappingFn getWindowMappingFn() { public Coder>> getCoderInternal() { return coder; } + + @Override + public Map, PValue> expand() { + throw new UnsupportedOperationException(String.format( + "A %s cannot be expanded", RunnerPCollectionView.class.getSimpleName())); + } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java index c2c0afa730507..9ca745d4670d9 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java @@ -79,6 +79,10 @@ private DirectGraph( return rootTransforms; } + Set> getPCollections() { + return producers.keySet(); + } + Set> getViews() { return viewWriters.keySet(); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index d54de5d9b5ee9..07bcf06926cb4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -21,15 +21,18 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; @@ -44,6 +47,7 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { private Map, AppliedPTransform> producers = new HashMap<>(); private Map, AppliedPTransform> viewWriters = new HashMap<>(); + private Set> consumedViews = new HashSet<>(); private ListMultimap> primitiveConsumers = ArrayListMultimap.create(); @@ -73,6 +77,13 @@ public void leaveCompositeTransform(TransformHierarchy.Node node) { getClass().getSimpleName()); if (node.isRootNode()) { finalized = true; + checkState( + viewWriters.keySet().containsAll(consumedViews), + "All %ss that are consumed must be written by some %s %s: Missing %s", + PCollectionView.class.getSimpleName(), + WriteView.class.getSimpleName(), + PTransform.class.getSimpleName(), + Sets.difference(consumedViews, viewWriters.keySet())); } } @@ 
-86,11 +97,12 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { for (PValue value : node.getInputs().values()) { primitiveConsumers.put(value, appliedTransform); } - if (node.getTransform() instanceof ViewOverrideFactory.WriteView) { - viewWriters.put( - ((ViewOverrideFactory.WriteView) node.getTransform()).getView(), - node.toAppliedPTransform(getPipeline())); - } + } + if (node.getTransform() instanceof ParDo.MultiOutput) { + consumedViews.addAll(((ParDo.MultiOutput) node.getTransform()).getSideInputs()); + } else if (node.getTransform() instanceof ViewOverrideFactory.WriteView) { + viewWriters.put( + ((WriteView) node.getTransform()).getView(), node.toAppliedPTransform(getPipeline())); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java index 057f4a1836ca0..8a281a7944fa7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java @@ -28,7 +28,6 @@ import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; /** * The {@link DirectRunner} {@link TransformEvaluatorFactory} for the {@link CreatePCollectionView} @@ -60,12 +59,13 @@ public TransformEvaluator forApplication( public void cleanup() throws Exception {} private TransformEvaluator> createEvaluator( - final AppliedPTransform>, PCollectionView, WriteView> + final AppliedPTransform< + PCollection>, PCollection>, WriteView> application) { PCollection> input = (PCollection>) Iterables.getOnlyElement(application.getInputs().values()); - final PCollectionViewWriter writer = context.createPCollectionViewWriter(input, - (PCollectionView) Iterables.getOnlyElement(application.getOutputs().values())); + final PCollectionViewWriter writer = + context.createPCollectionViewWriter(input, application.getTransform().getView()); return new TransformEvaluator>() { private final List> elements = new ArrayList<>(); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java index fdff63d803f2e..06a73889a1aeb 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java @@ -18,11 +18,11 @@ package org.apache.beam.runners.direct; -import java.util.Collections; import java.util.Map; import org.apache.beam.runners.core.construction.ForwardingPTransform; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; +import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -43,12 +43,12 @@ */ class ViewOverrideFactory implements PTransformOverrideFactory< - PCollection, PCollectionView, CreatePCollectionView> { + PCollection, PCollection, CreatePCollectionView> { @Override - public PTransformReplacement, PCollectionView> getReplacementTransform( + public PTransformReplacement, PCollection> 
getReplacementTransform( AppliedPTransform< - PCollection, PCollectionView, CreatePCollectionView> + PCollection, PCollection, CreatePCollectionView> transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), @@ -57,13 +57,13 @@ public PTransformReplacement, PCollectionView> getRepl @Override public Map mapOutputs( - Map, PValue> outputs, PCollectionView newOutput) { - return Collections.emptyMap(); + Map, PValue> outputs, PCollection newOutput) { + return ReplacementOutputs.singleton(outputs, newOutput); } /** The {@link DirectRunner} composite override for {@link CreatePCollectionView}. */ static class GroupAndWriteView - extends ForwardingPTransform, PCollectionView> { + extends ForwardingPTransform, PCollection> { private final CreatePCollectionView og; private GroupAndWriteView(CreatePCollectionView og) { @@ -71,17 +71,18 @@ private GroupAndWriteView(CreatePCollectionView og) { } @Override - public PCollectionView expand(PCollection input) { - return input + public PCollection expand(final PCollection input) { + input .apply(WithKeys.of((Void) null)) .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())) .apply(GroupByKey.create()) .apply(Values.>create()) .apply(new WriteView(og)); + return input; } @Override - protected PTransform, PCollectionView> delegate() { + protected PTransform, PCollection> delegate() { return og; } } @@ -94,7 +95,7 @@ protected PTransform, PCollectionView> delegate() { * to {@link ViewT}. */ static final class WriteView - extends RawPTransform>, PCollectionView> { + extends RawPTransform>, PCollection>> { private final CreatePCollectionView og; WriteView(CreatePCollectionView og) { @@ -103,8 +104,10 @@ static final class WriteView @Override @SuppressWarnings("deprecation") - public PCollectionView expand(PCollection> input) { - return og.getView(); + public PCollection> expand(PCollection> input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); } @SuppressWarnings("deprecation") diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java index 43de091b263d7..7707f7fe489d5 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.direct; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -25,6 +26,12 @@ /** Test utilities for the {@link DirectRunner}. 
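A brief usage sketch of the new performDirectOverrides helper, mirroring how the updated tests elsewhere in this patch call it (the pipeline contents here are illustrative, not taken from any single test):

    Pipeline p = Pipeline.create();
    PCollection<String> created = p.apply(Create.of("foo", "spam", "third"));

    // Apply the DirectRunner's transform overrides first, so view-producing
    // transforms are replaced before the graph is inspected.
    DirectGraphs.performDirectOverrides(p);
    DirectGraph graph = DirectGraphs.getGraph(p);
    AppliedPTransform<?, ?, ?> createdProducer = graph.getProducer(created);

Because CreatePCollectionView now expands into a primitive WriteView whose output is a plain PCollection, tests that inspect producers run the overrides before calling getGraph, which is the recurring change applied to the test classes in this patch.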
*/ final class DirectGraphs { + public static void performDirectOverrides(Pipeline p) { + p.replaceAll( + DirectRunner.fromOptions(PipelineOptionsFactory.create().as(DirectOptions.class)) + .defaultTransformOverrides()); + } + public static DirectGraph getGraph(Pipeline p) { DirectGraphVisitor visitor = new DirectGraphVisitor(); p.traverseTopologically(visitor); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java index c0e43d6c64681..f3edf552b27e8 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java @@ -101,14 +101,13 @@ public void setup() { view = created.apply(View.asIterable()); unbounded = p.apply(GenerateSequence.from(0)); - p.replaceAll( - DirectRunner.fromOptions(TestPipeline.testingPipelineOptions()) - .defaultTransformOverrides()); + p.replaceAll(runner.defaultTransformOverrides()); KeyedPValueTrackingVisitor keyedPValueTrackingVisitor = KeyedPValueTrackingVisitor.create(); p.traverseTopologically(keyedPValueTrackingVisitor); BundleFactory bundleFactory = ImmutableListBundleFactory.create(); + DirectGraphs.performDirectOverrides(p); graph = DirectGraphs.getGraph(p); context = EvaluationContext.create( diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java index c0919b9509fc5..365b6c43ade92 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java @@ -64,7 +64,9 @@ public void processElement(ProcessContext c) c.element()[0] = 'b'; } })); - consumer = DirectGraphs.getProducer(pcollection.apply(Count.globally())); + PCollection consumer = pcollection.apply(Count.globally()); + DirectGraphs.performDirectOverrides(p); + this.consumer = DirectGraphs.getProducer(consumer); } @Test diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index 09a21ac524a59..df84cbf6f2c9a 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -149,6 +149,7 @@ private ParDoEvaluator createEvaluator( Mockito.any(AppliedPTransform.class), Mockito.any(StructuralKey.class))) .thenReturn(executionContext); + DirectGraphs.performDirectOverrides(p); @SuppressWarnings("unchecked") AppliedPTransform, ?, ?> transform = (AppliedPTransform, ?, ?>) DirectGraphs.getProducer(output); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java index 86412a0234e64..3dd4028af6054 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java @@ -90,6 +90,7 @@ public void setup() { created = p.apply(Create.of("foo", "spam", "third")); PCollection> 
downstream = created.apply(WithKeys.of(3)); + DirectGraphs.performDirectOverrides(p); DirectGraph graph = DirectGraphs.getGraph(p); createdProducer = graph.getProducer(created); downstreamProducer = graph.getProducer(downstream); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java index 419698e29798c..ad1aecce623b1 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java @@ -36,7 +36,6 @@ import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; import org.joda.time.Instant; import org.junit.Rule; @@ -66,12 +65,12 @@ public void testInMemoryEvaluator() throws Exception { .setCoder(KvCoder.of(VoidCoder.of(), StringUtf8Coder.of())) .apply(GroupByKey.create()) .apply(Values.>create()); - PCollectionView> view = + PCollection> view = concat.apply(new ViewOverrideFactory.WriteView<>(createView)); EvaluationContext context = mock(EvaluationContext.class); TestViewWriter> viewWriter = new TestViewWriter<>(); - when(context.createPCollectionViewWriter(concat, view)).thenReturn(viewWriter); + when(context.createPCollectionViewWriter(concat, createView.getView())).thenReturn(viewWriter); CommittedBundle inputBundle = bundleFactory.createBundle(input).commit(Instant.now()); AppliedPTransform producer = DirectGraphs.getProducer(view); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java index 024e15c4c5361..94728c7909209 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.direct; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; @@ -63,11 +62,11 @@ public void replacementSucceeds() { PCollection ints = p.apply("CreateContents", Create.of(1, 2, 3)); final PCollectionView> view = PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder()); - PTransformReplacement, PCollectionView>> + PTransformReplacement, PCollection> replacementTransform = factory.getReplacementTransform( AppliedPTransform - ., PCollectionView>, + ., PCollection, CreatePCollectionView>> of( "foo", @@ -75,12 +74,7 @@ public void replacementSucceeds() { view.expand(), CreatePCollectionView.>of(view), p)); - PCollectionView> afterReplacement = - ints.apply(replacementTransform.getTransform()); - assertThat( - "The CreatePCollectionView replacement should return the same View", - afterReplacement, - equalTo(view)); + ints.apply(replacementTransform.getTransform()); PCollection> outputViewContents = p.apply("CreateSingleton", Create.of(0)) @@ -104,10 +98,10 @@ public void replacementGetViewReturnsOriginal() { final PCollection ints = p.apply("CreateContents", Create.of(1, 2, 3)); final PCollectionView> view = PCollectionViews.listView(ints, 
WindowingStrategy.globalDefault(), ints.getCoder()); - PTransformReplacement, PCollectionView>> replacement = + PTransformReplacement, PCollection> replacement = factory.getReplacementTransform( AppliedPTransform - ., PCollectionView>, + ., PCollection, CreatePCollectionView>> of( "foo", diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkCallbackExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkCallbackExecutorTest.java index b66734600ee99..1d8aac1d1c779 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkCallbackExecutorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkCallbackExecutorTest.java @@ -59,6 +59,7 @@ public class WatermarkCallbackExecutorTest { public void setup() { PCollection created = p.apply(Create.of(1, 2, 3)); PCollection summed = created.apply(Sum.integersGlobally()); + DirectGraphs.performDirectOverrides(p); DirectGraph graph = DirectGraphs.getGraph(p); create = graph.getProducer(created); sum = graph.getProducer(summed); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java index 9528ac93285ed..e0b52515d8c63 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java @@ -121,6 +121,7 @@ public void processElement(ProcessContext c) throws Exception { flattened = preFlatten.apply("flattened", Flatten.pCollections()); clock = MockClock.fromInstant(new Instant(1000)); + DirectGraphs.performDirectOverrides(p); graph = DirectGraphs.getGraph(p); manager = WatermarkManager.create(clock, graph); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java new file mode 100644 index 0000000000000..0cc3aec1d3aff --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.common.collect.Iterables; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.construction.ReplacementOutputs; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.View.CreatePCollectionView; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** Flink streaming overrides for various view (side input) transforms. */ +class CreateStreamingFlinkView + extends PTransform, PCollection> { + private final PCollectionView view; + + public CreateStreamingFlinkView(PCollectionView view) { + this.view = view; + } + + @Override + public PCollection expand(PCollection input) { + input + .apply(Combine.globally(new Concatenate()).withoutDefaults()) + .apply(CreateFlinkPCollectionView.of(view)); + return input; + } + + /** + * Combiner that combines {@code T}s into a single {@code List} containing all inputs. + * + *

For internal use by {@link CreateStreamingFlinkView}. This combiner requires that the input + * {@link PCollection} fits in memory. For a large {@link PCollection} this is expected to crash! + * + * @param the type of elements to concatenate. + */ + private static class Concatenate extends Combine.CombineFn, List> { + @Override + public List createAccumulator() { + return new ArrayList(); + } + + @Override + public List addInput(List accumulator, T input) { + accumulator.add(input); + return accumulator; + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + List result = createAccumulator(); + for (List accumulator : accumulators) { + result.addAll(accumulator); + } + return result; + } + + @Override + public List extractOutput(List accumulator) { + return accumulator; + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } + + @Override + public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } + } + + /** + * Creates a primitive {@link PCollectionView}. + * + *

For internal use only by runner implementors. + * + * @param The type of the elements of the input PCollection + * @param The type associated with the {@link PCollectionView} used as a side input + */ + public static class CreateFlinkPCollectionView + extends PTransform>, PCollection>> { + private PCollectionView view; + + private CreateFlinkPCollectionView(PCollectionView view) { + this.view = view; + } + + public static CreateFlinkPCollectionView of( + PCollectionView view) { + return new CreateFlinkPCollectionView<>(view); + } + + @Override + public PCollection> expand(PCollection> input) { + return PCollection.>createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); + } + + public PCollectionView getView() { + return view; + } + } + + public static class Factory + implements PTransformOverrideFactory< + PCollection, PCollection, CreatePCollectionView> { + public Factory() {} + + @Override + public PTransformReplacement, PCollection> getReplacementTransform( + AppliedPTransform< + PCollection, PCollection, CreatePCollectionView> + transform) { + return PTransformReplacement.of( + (PCollection) Iterables.getOnlyElement(transform.getInputs().values()), + new CreateStreamingFlinkView(transform.getTransform().getView())); + } + + @Override + public Map mapOutputs( + Map, PValue> outputs, PCollection newOutput) { + return ReplacementOutputs.singleton(outputs, newOutput); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java index 8da68c5fc11ef..a88ff071fcaca 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java @@ -33,10 +33,9 @@ import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.TransformHierarchy; -import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; -import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.InstanceBuilder; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; @@ -85,37 +84,8 @@ public void translate(Pipeline pipeline) { new SplittableParDoViaKeyedWorkItems.OverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsIterable.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingViewAsIterable.class, flinkRunner))) - .add( - PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsList.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingViewAsList.class, flinkRunner))) - .add( - PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsMap.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingViewAsMap.class, flinkRunner))) - .add( - PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsMultimap.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingViewAsMultimap.class, flinkRunner))) - .add( - PTransformOverride.of( - 
PTransformMatchers.classEqualTo(View.AsSingleton.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingViewAsSingleton.class, flinkRunner))) - // this has to be last since the ViewAsSingleton override - // can expand to a Combine.GloballyAsSingletonView - .add( - PTransformOverride.of( - PTransformMatchers.classEqualTo(Combine.GloballyAsSingletonView.class), - new ReflectiveOneToOneOverrideFactory( - FlinkStreamingViewOverrides.StreamingCombineGloballyAsSingletonView.class, - flinkRunner))) + PTransformMatchers.classEqualTo(CreatePCollectionView.class), + new CreateStreamingFlinkView.Factory())) .build(); // Ensure all outputs of all reads are consumed. diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 2a7c5d66c953d..ef46b63ae5619 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -124,7 +124,7 @@ class FlinkStreamingTransformTranslators { TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslator()); TRANSLATORS.put(Flatten.PCollections.class, new FlattenPCollectionTranslator()); TRANSLATORS.put( - FlinkStreamingViewOverrides.CreateFlinkPCollectionView.class, + CreateStreamingFlinkView.CreateFlinkPCollectionView.class, new CreateViewStreamingTranslator()); TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorStreaming()); @@ -584,17 +584,17 @@ OutputT> createDoFnOperator( private static class CreateViewStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - FlinkStreamingViewOverrides.CreateFlinkPCollectionView> { + CreateStreamingFlinkView.CreateFlinkPCollectionView> { @Override public void translateNode( - FlinkStreamingViewOverrides.CreateFlinkPCollectionView transform, + CreateStreamingFlinkView.CreateFlinkPCollectionView transform, FlinkStreamingTranslationContext context) { // just forward DataStream>> inputDataSet = context.getInputDataStream(context.getInput(transform)); - PCollectionView view = context.getOutput(transform); + PCollectionView view = transform.getView(); context.setOutputDataStream(view, inputDataSet); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingViewOverrides.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingViewOverrides.java deleted file mode 100644 index ce1c895555207..0000000000000 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingViewOverrides.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PCollectionViews; - -/** - * Flink streaming overrides for various view (side input) transforms. - */ -class FlinkStreamingViewOverrides { - - /** - * Specialized implementation for - * {@link org.apache.beam.sdk.transforms.View.AsMap View.AsMap} - * for the Flink runner in streaming mode. - */ - static class StreamingViewAsMap - extends PTransform>, PCollectionView>> { - - private final transient FlinkRunner runner; - - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingViewAsMap(FlinkRunner runner, View.AsMap transform) { - this.runner = runner; - } - - @Override - public PCollectionView> expand(PCollection> input) { - PCollectionView> view = - PCollectionViews.mapView( - input, - input.getWindowingStrategy(), - input.getCoder()); - - @SuppressWarnings({"rawtypes", "unchecked"}) - KvCoder inputCoder = (KvCoder) input.getCoder(); - try { - inputCoder.getKeyCoder().verifyDeterministic(); - } catch (Coder.NonDeterministicException e) { - runner.recordViewUsesNonDeterministicKeyCoder(this); - } - - return input - .apply(Combine.globally(new Concatenate>()).withoutDefaults()) - .apply(CreateFlinkPCollectionView., Map>of(view)); - } - - @Override - protected String getKindString() { - return "StreamingViewAsMap"; - } - } - - /** - * Specialized expansion for {@link - * View.AsMultimap View.AsMultimap} for the - * Flink runner in streaming mode. - */ - static class StreamingViewAsMultimap - extends PTransform>, PCollectionView>>> { - - private final transient FlinkRunner runner; - - /** - * Builds an instance of this class from the overridden transform. - */ - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingViewAsMultimap(FlinkRunner runner, View.AsMultimap transform) { - this.runner = runner; - } - - @Override - public PCollectionView>> expand(PCollection> input) { - PCollectionView>> view = - PCollectionViews.multimapView( - input, - input.getWindowingStrategy(), - input.getCoder()); - - @SuppressWarnings({"rawtypes", "unchecked"}) - KvCoder inputCoder = (KvCoder) input.getCoder(); - try { - inputCoder.getKeyCoder().verifyDeterministic(); - } catch (Coder.NonDeterministicException e) { - runner.recordViewUsesNonDeterministicKeyCoder(this); - } - - return input - .apply(Combine.globally(new Concatenate>()).withoutDefaults()) - .apply(CreateFlinkPCollectionView., Map>>of(view)); - } - - @Override - protected String getKindString() { - return "StreamingViewAsMultimap"; - } - } - - /** - * Specialized implementation for - * {@link View.AsList View.AsList} for the - * Flink runner in streaming mode. 
- */ - static class StreamingViewAsList - extends PTransform, PCollectionView>> { - /** - * Builds an instance of this class from the overridden transform. - */ - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingViewAsList(FlinkRunner runner, View.AsList transform) {} - - @Override - public PCollectionView> expand(PCollection input) { - PCollectionView> view = - PCollectionViews.listView( - input, - input.getWindowingStrategy(), - input.getCoder()); - - return input.apply(Combine.globally(new Concatenate()).withoutDefaults()) - .apply(CreateFlinkPCollectionView.>of(view)); - } - - @Override - protected String getKindString() { - return "StreamingViewAsList"; - } - } - - /** - * Specialized implementation for - * {@link View.AsIterable View.AsIterable} for the - * Flink runner in streaming mode. - */ - static class StreamingViewAsIterable - extends PTransform, PCollectionView>> { - /** - * Builds an instance of this class from the overridden transform. - */ - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingViewAsIterable(FlinkRunner runner, View.AsIterable transform) { } - - @Override - public PCollectionView> expand(PCollection input) { - PCollectionView> view = - PCollectionViews.iterableView( - input, - input.getWindowingStrategy(), - input.getCoder()); - - return input.apply(Combine.globally(new Concatenate()).withoutDefaults()) - .apply(CreateFlinkPCollectionView.>of(view)); - } - - @Override - protected String getKindString() { - return "StreamingViewAsIterable"; - } - } - - /** - * Specialized expansion for - * {@link View.AsSingleton View.AsSingleton} for the - * Flink runner in streaming mode. - */ - static class StreamingViewAsSingleton - extends PTransform, PCollectionView> { - private View.AsSingleton transform; - - /** - * Builds an instance of this class from the overridden transform. - */ - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingViewAsSingleton(FlinkRunner runner, View.AsSingleton transform) { - this.transform = transform; - } - - @Override - public PCollectionView expand(PCollection input) { - Combine.Globally combine = Combine.globally( - new SingletonCombine<>(transform.hasDefaultValue(), transform.defaultValue())); - if (!transform.hasDefaultValue()) { - combine = combine.withoutDefaults(); - } - return input.apply(combine.asSingletonView()); - } - - @Override - protected String getKindString() { - return "StreamingViewAsSingleton"; - } - - private static class SingletonCombine extends Combine.BinaryCombineFn { - private boolean hasDefaultValue; - private T defaultValue; - - SingletonCombine(boolean hasDefaultValue, T defaultValue) { - this.hasDefaultValue = hasDefaultValue; - this.defaultValue = defaultValue; - } - - @Override - public T apply(T left, T right) { - throw new IllegalArgumentException("PCollection with more than one element " - + "accessed as a singleton view. Consider using Combine.globally().asSingleton() to " - + "combine the PCollection into a single value"); - } - - @Override - public T identity() { - if (hasDefaultValue) { - return defaultValue; - } else { - throw new IllegalArgumentException( - "Empty PCollection accessed as a singleton view. 
" - + "Consider setting withDefault to provide a default value"); - } - } - } - } - - static class StreamingCombineGloballyAsSingletonView - extends PTransform, PCollectionView> { - Combine.GloballyAsSingletonView transform; - - /** - * Builds an instance of this class from the overridden transform. - */ - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() - public StreamingCombineGloballyAsSingletonView( - FlinkRunner runner, - Combine.GloballyAsSingletonView transform) { - this.transform = transform; - } - - @Override - public PCollectionView expand(PCollection input) { - PCollection combined = - input.apply(Combine.globally(transform.getCombineFn()) - .withoutDefaults() - .withFanout(transform.getFanout())); - - PCollectionView view = PCollectionViews.singletonView( - combined, - combined.getWindowingStrategy(), - transform.getInsertDefault(), - transform.getInsertDefault() - ? transform.getCombineFn().defaultValue() : null, - combined.getCoder()); - return combined - .apply(ParDo.of(new WrapAsList())) - .apply(CreateFlinkPCollectionView.of(view)); - } - - @Override - protected String getKindString() { - return "StreamingCombineGloballyAsSingletonView"; - } - } - - private static class WrapAsList extends DoFn> { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(Collections.singletonList(c.element())); - } - } - - /** - * Combiner that combines {@code T}s into a single {@code List} containing all inputs. - * - *

For internal use by {@link StreamingViewAsMap}, {@link StreamingViewAsMultimap}, - * {@link StreamingViewAsList}, {@link StreamingViewAsIterable}. - * They require the input {@link PCollection} fits in memory. - * For a large {@link PCollection} this is expected to crash! - * - * @param the type of elements to concatenate. - */ - private static class Concatenate extends Combine.CombineFn, List> { - @Override - public List createAccumulator() { - return new ArrayList(); - } - - @Override - public List addInput(List accumulator, T input) { - accumulator.add(input); - return accumulator; - } - - @Override - public List mergeAccumulators(Iterable> accumulators) { - List result = createAccumulator(); - for (List accumulator : accumulators) { - result.addAll(accumulator); - } - return result; - } - - @Override - public List extractOutput(List accumulator) { - return accumulator; - } - - @Override - public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { - return ListCoder.of(inputCoder); - } - - @Override - public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { - return ListCoder.of(inputCoder); - } - } - - /** - * Creates a primitive {@link PCollectionView}. - * - *

For internal use only by runner implementors. - * - * @param The type of the elements of the input PCollection - * @param The type associated with the {@link PCollectionView} used as a side input - */ - public static class CreateFlinkPCollectionView - extends PTransform>, PCollectionView> { - private PCollectionView view; - - private CreateFlinkPCollectionView(PCollectionView view) { - this.view = view; - } - - public static CreateFlinkPCollectionView of( - PCollectionView view) { - return new CreateFlinkPCollectionView<>(view); - } - - @Override - public PCollectionView expand(PCollection> input) { - return view; - } - } -} diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java index b4a6e6470a4ca..ad3faed1cca14 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java @@ -39,8 +39,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import org.apache.beam.runners.core.construction.PTransformReplacements; -import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.dataflow.internal.IsmFormat; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecordCoder; @@ -57,17 +55,11 @@ import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView; -import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.View.AsSingleton; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -83,7 +75,6 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; @@ -192,12 +183,13 @@ public void processElement(ProcessContext c) } private final DataflowRunner runner; - /** - * Builds an instance of this class from the overridden transform. - */ + private final PCollectionView> view; + /** Builds an instance of this class from the overridden transform. 
*/ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsMap(DataflowRunner runner, View.AsMap transform) { + public BatchViewAsMap( + DataflowRunner runner, CreatePCollectionView, Map> transform) { this.runner = runner; + this.view = transform.getView(); } @Override @@ -207,12 +199,7 @@ public PCollectionView> expand(PCollection> input) { private PCollectionView> applyInternal(PCollection> input) { - - @SuppressWarnings({"rawtypes", "unchecked"}) - KvCoder inputCoder = (KvCoder) input.getCoder(); try { - PCollectionView> view = PCollectionViews.mapView( - input, input.getWindowingStrategy(), inputCoder); return BatchViewAsMultimap.applyForMapLike(runner, input, view, true /* unique keys */); } catch (NonDeterministicException e) { runner.recordViewUsesNonDeterministicKeyCoder(this); @@ -249,19 +236,14 @@ protected String getKindString() { inputCoder.getKeyCoder(), FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder))); - TransformedMap, V> defaultValue = new TransformedMap<>( - WindowedValueToValue.of(), - ImmutableMap.>of()); - return BatchViewAsSingleton., TransformedMap, V>, Map, W> applyForSingleton( runner, input, new ToMapDoFn(windowCoder), - true, - defaultValue, - finalValueCoder); + finalValueCoder, + view); } } @@ -680,12 +662,13 @@ public void processElement(ProcessContext c) } private final DataflowRunner runner; - /** - * Builds an instance of this class from the overridden transform. - */ + private final PCollectionView>> view; + /** Builds an instance of this class from the overridden transform. */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsMultimap(DataflowRunner runner, View.AsMultimap transform) { + public BatchViewAsMultimap( + DataflowRunner runner, CreatePCollectionView, Map>> transform) { this.runner = runner; + this.view = transform.getView(); } @Override @@ -695,12 +678,7 @@ public PCollectionView>> expand(PCollection> input) private PCollectionView>> applyInternal(PCollection> input) { - @SuppressWarnings({"rawtypes", "unchecked"}) - KvCoder inputCoder = (KvCoder) input.getCoder(); try { - PCollectionView>> view = PCollectionViews.multimapView( - input, input.getWindowingStrategy(), inputCoder); - return applyForMapLike(runner, input, view, false /* unique keys not expected */); } catch (NonDeterministicException e) { runner.recordViewUsesNonDeterministicKeyCoder(this); @@ -738,16 +716,15 @@ public PCollectionView>> expand(PCollection> input) IterableWithWindowedValuesToIterable.of(), ImmutableMap.>>of()); - return BatchViewAsSingleton., - TransformedMap>, Iterable>, - Map>, - W> applyForSingleton( - runner, - input, - new ToMultimapDoFn(windowCoder), - true, - defaultValue, - finalValueCoder); + return BatchViewAsSingleton + ., TransformedMap>, Iterable>, + Map>, W> + applyForSingleton( + runner, + input, + new ToMultimapDoFn(windowCoder), + finalValueCoder, + view); } private static PCollectionView applyForMapLike( @@ -827,10 +804,9 @@ private static PCollectionView app PCollectionList.of(ImmutableList.of( perHashWithReifiedWindows, windowMapSizeMetadata, windowMapKeysMetadata)); - return Pipeline.applyTransform(outputs, - Flatten.>>pCollections()) - .apply(CreateDataflowView.>, - ViewT>of(view)); + Pipeline.applyTransform(outputs, Flatten.>>pCollections()) + .apply(CreateDataflowView.>, ViewT>of(view)); + return view; } @Override @@ -915,14 +891,12 @@ public void processElement(ProcessContext c) throws Exception { } private final DataflowRunner runner; 
- private final View.AsSingleton transform; - /** - * Builds an instance of this class from the overridden transform. - */ + private final PCollectionView view; + /** Builds an instance of this class from the overridden transform. */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsSingleton(DataflowRunner runner, View.AsSingleton transform) { + public BatchViewAsSingleton(DataflowRunner runner, CreatePCollectionView transform) { this.runner = runner; - this.transform = transform; + this.view = transform.getView(); } @Override @@ -935,9 +909,8 @@ public PCollectionView expand(PCollection input) { runner, input, new IsmRecordForSingularValuePerWindowDoFn(windowCoder), - transform.hasDefaultValue(), - transform.defaultValue(), - input.getCoder()); + input.getCoder(), + view); } static PCollectionView @@ -946,23 +919,13 @@ public PCollectionView expand(PCollection input) { PCollection input, DoFn>>>, IsmRecord>> doFn, - boolean hasDefault, - FinalT defaultValue, - Coder defaultValueCoder) { + Coder defaultValueCoder, + PCollectionView view) { @SuppressWarnings("unchecked") Coder windowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); - @SuppressWarnings({"rawtypes", "unchecked"}) - PCollectionView view = - (PCollectionView) PCollectionViews.singletonView( - (PCollection) input, - (WindowingStrategy) input.getWindowingStrategy(), - hasDefault, - defaultValue, - defaultValueCoder); - IsmRecordCoder> ismCoder = coderForSingleton(windowCoder, defaultValueCoder); @@ -972,8 +935,9 @@ public PCollectionView expand(PCollection input) { reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - return reifiedPerWindowAndSorted.apply( + reifiedPerWindowAndSorted.apply( CreateDataflowView.>, ViewT>of(view)); + return view; } @Override @@ -1079,18 +1043,18 @@ public void processElement(ProcessContext c) throws Exception { } private final DataflowRunner runner; + private final PCollectionView> view; /** * Builds an instance of this class from the overridden transform. */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsList(DataflowRunner runner, View.AsList transform) { + public BatchViewAsList(DataflowRunner runner, CreatePCollectionView> transform) { this.runner = runner; + this.view = transform.getView(); } @Override public PCollectionView> expand(PCollection input) { - PCollectionView> view = PCollectionViews.listView( - input, input.getWindowingStrategy(), input.getCoder()); return applyForIterableLike(runner, input, view); } @@ -1116,8 +1080,9 @@ static PCollectionView applyForIterab reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - return reifiedPerWindowAndSorted.apply( + reifiedPerWindowAndSorted.apply( CreateDataflowView.>, ViewT>of(view)); + return view; } PCollection>> reifiedPerWindowAndSorted = input @@ -1126,8 +1091,9 @@ static PCollectionView applyForIterab reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - return reifiedPerWindowAndSorted.apply( + reifiedPerWindowAndSorted.apply( CreateDataflowView.>, ViewT>of(view)); + return view; } @Override @@ -1164,18 +1130,17 @@ static class BatchViewAsIterable extends PTransform, PCollectionView>> { private final DataflowRunner runner; - /** - * Builds an instance of this class from the overridden transform. 
- */ + private final PCollectionView> view; + /** Builds an instance of this class from the overridden transform. */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsIterable(DataflowRunner runner, View.AsIterable transform) { + public BatchViewAsIterable( + DataflowRunner runner, CreatePCollectionView> transform) { this.runner = runner; + this.view = transform.getView(); } @Override public PCollectionView> expand(PCollection input) { - PCollectionView> view = PCollectionViews.iterableView( - input, input.getWindowingStrategy(), input.getCoder()); return BatchViewAsList.applyForIterableLike(runner, input, view); } } @@ -1377,59 +1342,4 @@ public void verifyDeterministic() verifyDeterministic(this, "Expected map coder to be deterministic.", originalMapCoder); } } - - static class BatchCombineGloballyAsSingletonViewFactory - extends SingleInputOutputOverrideFactory< - PCollection, PCollectionView, - Combine.GloballyAsSingletonView> { - private final DataflowRunner runner; - - BatchCombineGloballyAsSingletonViewFactory(DataflowRunner runner) { - this.runner = runner; - } - - @Override - public PTransformReplacement, PCollectionView> - getReplacementTransform( - AppliedPTransform< - PCollection, PCollectionView, - GloballyAsSingletonView> - transform) { - GloballyAsSingletonView combine = transform.getTransform(); - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(transform), - new BatchCombineGloballyAsSingletonView<>( - runner, combine.getCombineFn(), combine.getFanout(), combine.getInsertDefault())); - } - - private static class BatchCombineGloballyAsSingletonView - extends PTransform, PCollectionView> { - private final DataflowRunner runner; - private final GlobalCombineFn combineFn; - private final int fanout; - private final boolean insertDefault; - - BatchCombineGloballyAsSingletonView( - DataflowRunner runner, - GlobalCombineFn combineFn, - int fanout, - boolean insertDefault) { - this.runner = runner; - this.combineFn = combineFn; - this.fanout = fanout; - this.insertDefault = insertDefault; - } - - @Override - public PCollectionView expand(PCollection input) { - PCollection combined = - input.apply(Combine.globally(combineFn).withoutDefaults().withFanout(fanout)); - AsSingleton viewAsSingleton = View.asSingleton(); - if (insertDefault) { - viewAsSingleton.withDefaultValue(combineFn.defaultValue()); - } - return combined.apply(new BatchViewAsSingleton<>(runner, viewAsSingleton)); - } - } - } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java index e7542cb38a453..caad7f8406c81 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java @@ -24,7 +24,7 @@ /** A {@link DataflowRunner} marker class for creating a {@link PCollectionView}. 
*/ public class CreateDataflowView - extends PTransform, PCollectionView> { + extends PTransform, PCollection> { public static CreateDataflowView of(PCollectionView view) { return new CreateDataflowView<>(view); } @@ -36,8 +36,10 @@ private CreateDataflowView(PCollectionView view) { } @Override - public PCollectionView expand(PCollection input) { - return view; + public PCollection expand(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); } public PCollectionView getView() { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 8eaf61b6d88f9..a3a7ab6bb1611 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -440,6 +440,14 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { public void visitValue(PValue value, TransformHierarchy.Node producer) { LOG.debug("Checking translation of {}", value); // Primitive transforms are the only ones assigned step names. + if (producer.getTransform() instanceof CreateDataflowView) { + // CreateDataflowView produces a dummy output (as it must be a primitive transform) but + // in the Dataflow Job graph produces only the view and not the output PCollection. + asOutputReference( + ((CreateDataflowView) producer.getTransform()).getView(), + producer.toAppliedPTransform(getPipeline())); + return; + } asOutputReference(value, producer.toAppliedPTransform(getPipeline())); } @@ -465,6 +473,7 @@ public StepTranslator addStep(PTransform transform, String type) { StepTranslator stepContext = new StepTranslator(this, step); stepContext.addInput(PropertyNames.USER_NAME, getFullName(transform)); stepContext.addDisplayData(step, stepName, transform); + LOG.info("Adding {} as step {}", getCurrentTransform(transform).getFullName(), stepName); return stepContext; } @@ -677,7 +686,7 @@ private void translateTyped( context.addStep(transform, "CollectionToSingleton"); PCollection input = context.getInput(transform); stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); - stepContext.addCollectionToSingletonOutput(input, context.getOutput(transform)); + stepContext.addCollectionToSingletonOutput(input, transform.getView()); } }); diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 3e7c8ce98475a..ea9db24ff638e 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -67,7 +67,6 @@ import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource; import org.apache.beam.runners.core.construction.UnconsumedReads; -import org.apache.beam.runners.dataflow.BatchViewOverrides.BatchCombineGloballyAsSingletonViewFactory; import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification; import 
org.apache.beam.runners.dataflow.StreamingViewOverrides.StreamingCreatePCollectionViewFactory; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; @@ -129,6 +128,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; @@ -350,34 +350,29 @@ private List getOverrides(boolean streaming) { PTransformOverride.of( PTransformMatchers.stateOrTimerParDoSingle(), BatchStatefulParDoOverrides.singleOutputOverrideFactory())) - - .add( - PTransformOverride.of( - PTransformMatchers.classEqualTo(Combine.GloballyAsSingletonView.class), - new BatchCombineGloballyAsSingletonViewFactory(this))) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsMap.class), + PTransformMatchers.createViewWithViewFn(PCollectionViews.MapViewFn.class), new ReflectiveOneToOneOverrideFactory( BatchViewOverrides.BatchViewAsMap.class, this))) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsMultimap.class), + PTransformMatchers.createViewWithViewFn(PCollectionViews.MultimapViewFn.class), new ReflectiveOneToOneOverrideFactory( BatchViewOverrides.BatchViewAsMultimap.class, this))) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsSingleton.class), + PTransformMatchers.createViewWithViewFn(PCollectionViews.SingletonViewFn.class), new ReflectiveOneToOneOverrideFactory( BatchViewOverrides.BatchViewAsSingleton.class, this))) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsList.class), + PTransformMatchers.createViewWithViewFn(PCollectionViews.ListViewFn.class), new ReflectiveOneToOneOverrideFactory( BatchViewOverrides.BatchViewAsList.class, this))) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(View.AsIterable.class), + PTransformMatchers.createViewWithViewFn(PCollectionViews.IterableViewFn.class), new ReflectiveOneToOneOverrideFactory( BatchViewOverrides.BatchViewAsIterable.class, this))); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java index 6c385d74085bd..18532486962c0 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java @@ -42,12 +42,12 @@ class StreamingViewOverrides { static class StreamingCreatePCollectionViewFactory extends SingleInputOutputOverrideFactory< - PCollection, PCollectionView, CreatePCollectionView> { + PCollection, PCollection, CreatePCollectionView> { @Override - public PTransformReplacement, PCollectionView> + public PTransformReplacement, PCollection> getReplacementTransform( AppliedPTransform< - PCollection, PCollectionView, CreatePCollectionView> + PCollection, PCollection, CreatePCollectionView> transform) { StreamingCreatePCollectionView streamingView = new StreamingCreatePCollectionView<>(transform.getTransform().getView()); @@ -56,7 +56,7 @@ static class StreamingCreatePCollectionViewFactory } private static class StreamingCreatePCollectionView - extends PTransform, PCollectionView> { + extends PTransform, PCollection> { private final 
PCollectionView view; private StreamingCreatePCollectionView(PCollectionView view) { @@ -64,7 +64,7 @@ private StreamingCreatePCollectionView(PCollectionView view) { } @Override - public PCollectionView expand(PCollection input) { + public PCollection expand(PCollection input) { return input .apply(Combine.globally(new Concatenate()).withoutDefaults()) .apply(ParDo.of(StreamingPCollectionViewWriterFn.create(view, input.getCoder()))) diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 89dc2d520564d..53215f60f1179 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -920,15 +920,15 @@ public void testToSingletonTranslationWithIsmSideInput() throws Exception { assertAllStepOutputsHaveUniqueIds(job); List steps = job.getSteps(); - assertEquals(5, steps.size()); + assertEquals(9, steps.size()); @SuppressWarnings("unchecked") List> toIsmRecordOutputs = - (List>) steps.get(3).getProperties().get(PropertyNames.OUTPUT_INFO); + (List>) steps.get(7).getProperties().get(PropertyNames.OUTPUT_INFO); assertTrue( Structs.getBoolean(Iterables.getOnlyElement(toIsmRecordOutputs), "use_indexed_format")); - Step collectionToSingletonStep = steps.get(4); + Step collectionToSingletonStep = steps.get(8); assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind()); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 64aa35a68fa89..ac5e0cd8e9f74 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -508,50 +508,6 @@ public String toNativeString() { }; } - private static TransformEvaluator> viewAsSingleton() { - return new TransformEvaluator>() { - @Override - public void evaluate(View.AsSingleton transform, EvaluationContext context) { - Iterable> iter = - context.getWindowedValues(context.getInput(transform)); - PCollectionView output = context.getOutput(transform); - Coder>> coderInternal = output.getCoderInternal(); - - @SuppressWarnings("unchecked") - Iterable> iterCast = (Iterable>) iter; - - context.putPView(output, iterCast, coderInternal); - } - - @Override - public String toNativeString() { - return "collect()"; - } - }; - } - - private static TransformEvaluator> viewAsIter() { - return new TransformEvaluator>() { - @Override - public void evaluate(View.AsIterable transform, EvaluationContext context) { - Iterable> iter = - context.getWindowedValues(context.getInput(transform)); - PCollectionView> output = context.getOutput(transform); - Coder>> coderInternal = output.getCoderInternal(); - - @SuppressWarnings("unchecked") - Iterable> iterCast = (Iterable>) iter; - - context.putPView(output, iterCast, coderInternal); - } - - @Override - public String toNativeString() { - return "collect()"; - } - }; - } - private static TransformEvaluator> createPCollView() { return new TransformEvaluator>() { @@ -560,7 +516,7 @@ public void evaluate(View.CreatePCollectionView transform, 
EvaluationContext context) { Iterable> iter = context.getWindowedValues(context.getInput(transform)); - PCollectionView output = context.getOutput(transform); + PCollectionView output = transform.getView(); Coder>> coderInternal = output.getCoderInternal(); @SuppressWarnings("unchecked") @@ -645,8 +601,8 @@ public String toNativeString() { EVALUATORS.put(Combine.PerKey.class, combinePerKey()); EVALUATORS.put(Flatten.PCollections.class, flattenPColl()); EVALUATORS.put(Create.Values.class, create()); - EVALUATORS.put(View.AsSingleton.class, viewAsSingleton()); - EVALUATORS.put(View.AsIterable.class, viewAsIter()); +// EVALUATORS.put(View.AsSingleton.class, viewAsSingleton()); +// EVALUATORS.put(View.AsIterable.class, viewAsIter()); EVALUATORS.put(View.CreatePCollectionView.class, createPCollView()); EVALUATORS.put(Window.Assign.class, window()); EVALUATORS.put(Reshuffle.class, reshuffle()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 2f0e8efd7de8f..ee1ce7b2b6683 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -24,10 +24,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; @@ -165,7 +168,7 @@ public void finishSpecifyingInput() { * nodes. 
*/ public void setOutput(POutput output) { - for (PValue value : output.expand().values()) { + for (PCollection value : fullyExpand(output).values()) { if (!producers.containsKey(value)) { producers.put(value, current); value.finishSpecifyingOutput( @@ -226,6 +229,47 @@ public Node getCurrent() { return current; } + private Map, PCollection> fullyExpand(POutput output) { + Map, PCollection> result = new LinkedHashMap<>(); + for (Map.Entry, PValue> value : output.expand().entrySet()) { + if (value.getValue() instanceof PCollection) { + PCollection previous = result.put(value.getKey(), (PCollection) value.getValue()); + checkArgument( + previous == null, + "Found conflicting %ss in flattened expansion of %s: %s maps to %s and %s", + output, + TupleTag.class.getSimpleName(), + value.getKey(), + previous, + value.getValue()); + } else { + if (value.getValue().expand().size() == 1 + && Iterables.getOnlyElement(value.getValue().expand().values()) + .equals(value.getValue())) { + throw new IllegalStateException( + String.format( + "Non %s %s that expands into itself %s", + PCollection.class.getSimpleName(), + PValue.class.getSimpleName(), + value.getValue())); + } + for (Map.Entry, PCollection> valueComponent : + fullyExpand(value.getValue()).entrySet()) { + PCollection previous = result.put(valueComponent.getKey(), valueComponent.getValue()); + checkArgument( + previous == null, + "Found conflicting %ss in flattened expansion of %s: %s maps to %s and %s", + output, + TupleTag.class.getSimpleName(), + valueComponent.getKey(), + previous, + valueComponent.getValue()); + } + } + } + return result; + } + /** * Provides internal tracking of transform relationships with helper methods * for initialization and ordered visitation. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index 9e1cc7113b804..6a90bcfde2e48 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -1277,14 +1277,15 @@ private GloballyAsSingletonView( public PCollectionView expand(PCollection input) { PCollection combined = input.apply(Combine.globally(fn).withoutDefaults().withFanout(fanout)); - return combined.apply( - CreatePCollectionView.of( - PCollectionViews.singletonView( - combined, - input.getWindowingStrategy(), - insertDefault, - insertDefault ? fn.defaultValue() : null, - combined.getCoder()))); + PCollectionView view = + PCollectionViews.singletonView( + combined, + input.getWindowingStrategy(), + insertDefault, + insertDefault ? 
fn.defaultValue() : null, + combined.getCoder()); + combined.apply(CreatePCollectionView.of(view)); + return view; } public int getFanout() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java index 073c750901da7..331b143f76c67 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java @@ -257,8 +257,10 @@ public PCollectionView> expand(PCollection input) { throw new IllegalStateException("Unable to create a side-input view from input", e); } - return input.apply(CreatePCollectionView.>of(PCollectionViews.listView( - input, input.getWindowingStrategy(), input.getCoder()))); + PCollectionView> view = + PCollectionViews.listView(input, input.getWindowingStrategy(), input.getCoder()); + input.apply(CreatePCollectionView.>of(view)); + return view; } } @@ -282,8 +284,10 @@ public PCollectionView> expand(PCollection input) { throw new IllegalStateException("Unable to create a side-input view from input", e); } - return input.apply(CreatePCollectionView.>of(PCollectionViews.iterableView( - input, input.getWindowingStrategy(), input.getCoder()))); + PCollectionView> view = + PCollectionViews.iterableView(input, input.getWindowingStrategy(), input.getCoder()); + input.apply(CreatePCollectionView.>of(view)); + return view; } } @@ -423,11 +427,10 @@ public PCollectionView>> expand(PCollection> input) throw new IllegalStateException("Unable to create a side-input view from input", e); } - return input.apply(CreatePCollectionView., Map>>of( - PCollectionViews.multimapView( - input, - input.getWindowingStrategy(), - input.getCoder()))); + PCollectionView>> view = + PCollectionViews.multimapView(input, input.getWindowingStrategy(), input.getCoder()); + input.apply(CreatePCollectionView., Map>>of(view)); + return view; } } @@ -459,11 +462,10 @@ public PCollectionView> expand(PCollection> input) { throw new IllegalStateException("Unable to create a side-input view from input", e); } - return input.apply(CreatePCollectionView., Map>of( - PCollectionViews.mapView( - input, - input.getWindowingStrategy(), - input.getCoder()))); + PCollectionView> view = + PCollectionViews.mapView(input, input.getWindowingStrategy(), input.getCoder()); + input.apply(CreatePCollectionView., Map>of(view)); + return view; } } @@ -480,7 +482,7 @@ public PCollectionView> expand(PCollection> input) { */ @Internal public static class CreatePCollectionView - extends PTransform, PCollectionView> { + extends PTransform, PCollection> { private PCollectionView view; private CreatePCollectionView(PCollectionView view) { @@ -506,8 +508,10 @@ public PCollectionView getView() { } @Override - public PCollectionView expand(PCollection input) { - return view; + public PCollection expand(PCollection input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) + .setCoder(input.getCoder()); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java index f210fd8614e4f..4063d110cb7b3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java @@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static 
com.google.common.base.Preconditions.checkState; +import java.util.Collections; +import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Internal; @@ -226,6 +228,11 @@ public String getName() { return super.getName(); } + @Override + public final Map, PValue> expand() { + return Collections., PValue>singletonMap(tag, this); + } + /** * Sets the name of this {@link PCollection}. Returns {@code this}. * @@ -314,6 +321,11 @@ public IsBounded isBounded() { private IsBounded isBounded; + /** + * A local {@link TupleTag} used in the expansion of this {@link PValueBase}. + */ + private final TupleTag tag = new TupleTag<>(); + private PCollection(Pipeline p) { super(p); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java index 74887c71b3687..5e2e2c316e2e1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java @@ -169,6 +169,15 @@ private SingletonViewFn(boolean hasDefault, T defaultValue, Coder valueCoder) } } + /** + * Returns if a default value was specified. + */ + @Deprecated + @Internal + public boolean hasDefault() { + return hasDefault; + } + /** * Returns the default value that was specified. * @@ -491,5 +500,10 @@ public boolean equals(Object other) { public String toString() { return MoreObjects.toStringHelper(this).add("tag", tag).toString(); } + + @Override + public Map, PValue> expand() { + return Collections., PValue>singletonMap(tag, pCollection); + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java index 6f638d7b9a3a2..f312eac83746c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java @@ -19,8 +19,6 @@ import static com.google.common.base.Preconditions.checkState; -import java.util.Collections; -import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.transforms.PTransform; @@ -86,11 +84,6 @@ protected PValueBase() { */ private String name; - /** - * A local {@link TupleTag} used in the expansion of this {@link PValueBase}. - */ - private TupleTag tag = new TupleTag<>(); - /** * Whether this {@link PValueBase} has been finalized, and its core * properties, e.g., name, can no longer be changed. 
@@ -107,11 +100,6 @@ boolean isFinishedSpecifying() { return finishedSpecifying; } - @Override - public final Map, PValue> expand() { - return Collections., PValue>singletonMap(tag, this); - } - @Override public void finishSpecifying(PInput input, PTransform transform) { finishedSpecifying = true; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java index adf27f8f52793..aaf8b91d1845b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java @@ -22,7 +22,9 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; @@ -37,6 +39,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.PValueBase; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; @@ -349,5 +352,10 @@ public String toString() { .add("viewFn", viewFn) .toString(); } + + @Override + public Map, PValue> expand() { + return Collections., PValue>singletonMap(tag, pCollection); + } } } From 4ebebfdb34de3e209c033de15e32cf67ab346d44 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 7 Jun 2017 23:00:43 -0700 Subject: [PATCH 005/200] Avoid flakiness in data channel for empty streams. As empty stream is used as end-of-stream marker, don't ever send it as the data itself. --- .../apache_beam/runners/worker/data_plane.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 5edd0b4907509..7365db69f56a1 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -167,12 +167,18 @@ def input_elements(self, instruction_id, expected_targets): yield data def output_stream(self, instruction_id, target): + # TODO: Return an output stream that sends data + # to the Runner once a fixed size buffer is full. + # Currently we buffer all the data before sending + # any messages. def add_to_send_queue(data): - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_reference=instruction_id, - target=target, - data=data)) + if data: + self._to_send.put( + beam_fn_api_pb2.Elements.Data( + instruction_reference=instruction_id, + target=target, + data=data)) + # End of stream marker. self._to_send.put( beam_fn_api_pb2.Elements.Data( instruction_reference=instruction_id, From 3e04902008b410269b23179dc2146623ff1fbd0a Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 7 Jun 2017 17:46:36 -0700 Subject: [PATCH 006/200] Refine Python DirectRunner watermark advancement behavior This change helps prepare for streaming pipeline execution. 
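
Note (illustration only, not part of this patch): the hold derived from pending elements can be sketched in isolation. The snippet below uses hypothetical names (pending_hold, plain integer microsecond timestamps, and float('inf') standing in for WATERMARK_POS_INF) to show the rule the refreshed logic follows: hold the watermark one granule below the earliest pending element, or leave it unheld when nothing is pending.

# Standalone sketch, not Beam code: the watermark hold implied by pending
# elements, i.e. "hold just below the earliest pending element timestamp".

WATERMARK_POS_INF = float('inf')   # stand-in for the manager's +inf watermark
TIME_GRANULARITY_MICROS = 1        # smallest expressible interval, as in timestamp.py


def pending_hold(pending_timestamps_micros):
  """Returns the hold implied by the given pending element timestamps."""
  if not pending_timestamps_micros:
    # Nothing pending: no hold, the output watermark may advance freely.
    return WATERMARK_POS_INF
  # Otherwise advance at most to just before the earliest pending element.
  return min(pending_timestamps_micros) - TIME_GRANULARITY_MICROS


assert pending_hold([]) == WATERMARK_POS_INF
assert pending_hold([10, 25, 7]) == 6
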
--- .../runners/direct/watermark_manager.py | 20 ++++++++++++++++--- sdks/python/apache_beam/utils/timestamp.py | 5 +++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 3a135397e12f0..0d7cd4fd79c00 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -25,6 +25,7 @@ from apache_beam import pvalue from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.timestamp import TIME_GRANULARITY class WatermarkManager(object): @@ -193,9 +194,22 @@ def remove_pending(self, completed): def refresh(self): with self._lock: - pending_holder = (WatermarkManager.WATERMARK_NEG_INF - if self._pending else - WatermarkManager.WATERMARK_POS_INF) + min_pending_timestamp = WatermarkManager.WATERMARK_POS_INF + has_pending_elements = False + for input_bundle in self._pending: + # TODO(ccy): we can have the Bundle class keep track of the minimum + # timestamp so we don't have to do an iteration here. + for wv in input_bundle.get_elements_iterable(): + has_pending_elements = True + if wv.timestamp < min_pending_timestamp: + min_pending_timestamp = wv.timestamp + + # If there is a pending element with a certain timestamp, we can at most + # advance our watermark to the maximum timestamp less than that + # timestamp. + pending_holder = WatermarkManager.WATERMARK_POS_INF + if has_pending_elements: + pending_holder = min_pending_timestamp - TIME_GRANULARITY input_watermarks = [ tw.output_watermark for tw in self._input_transform_watermarks] diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index 5d1b48c14e33e..b3e840ee284e1 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -208,3 +208,8 @@ def __rmul__(self, other): def __mod__(self, other): other = Duration.of(other) return Duration(micros=self.micros % other.micros) + + +# The minimum granularity / interval expressible in a Timestamp / Duration +# object. +TIME_GRANULARITY = Duration(micros=1) From 156f326a16e15b4e22a189a2a263d11d7b273656 Mon Sep 17 00:00:00 2001 From: Colin Phipps Date: Mon, 5 Jun 2017 12:12:49 +0000 Subject: [PATCH 007/200] Raise entity limit per RPC to 9MB. This is closer to the API limit, while still leaving room for overhead. Brings the Java SDK into line with the Python SDK. Switch the unit test to use the size of each entity, which is what the connector is actually using, rather than the property size (which is slightly smaller and would cause the test to fail for some values). 
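
Aside (not part of this patch): the sizing rule the updated test mirrors is simply "pack writes until the next one would push the RPC over the byte budget". A hypothetical Python sketch of that packing rule, assuming each mutation's serialized size in bytes is known up front; it is not the connector's code.

# Hypothetical sketch: greedily pack mutation sizes into commit batches so
# each batch's byte total stays within a per-RPC budget.

BATCH_BYTES_LIMIT = 9 * 1000 * 1000  # same 9MB budget this patch adopts


def pack_into_batches(mutation_sizes, byte_limit=BATCH_BYTES_LIMIT):
  """Groups mutation sizes into batches whose byte totals stay under the limit."""
  batches, current, current_bytes = [], [], 0
  for size in mutation_sizes:
    if current and current_bytes + size > byte_limit:
      batches.append(current)
      current, current_bytes = [], 0
    current.append(size)
    current_bytes += size
  if current:
    batches.append(current)
  return batches


# With items of exactly 900KB this packs ten per batch; real entities carry
# extra key/encoding overhead, so the connector may fit fewer per request.
assert [len(b) for b in pack_into_batches([900 * 1000] * 12)] == [10, 2]
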
--- .../beam/sdk/io/gcp/datastore/DatastoreV1.java | 2 +- .../sdk/io/gcp/datastore/DatastoreV1Test.java | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java index b198a6f568b91..06b9c8af9319f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java @@ -213,7 +213,7 @@ public class DatastoreV1 { * the mutations themselves and not the CommitRequest wrapper around them. */ @VisibleForTesting - static final int DATASTORE_BATCH_UPDATE_BYTES_LIMIT = 5_000_000; + static final int DATASTORE_BATCH_UPDATE_BYTES_LIMIT = 9_000_000; /** * Returns an empty {@link DatastoreV1.Read} builder. Configure the source {@code projectId}, diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java index 460049e8b355e..229b1fbb23672 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java @@ -651,12 +651,14 @@ private void datastoreWriterFnTest(int numMutations) throws Exception { @Test public void testDatatoreWriterFnWithLargeEntities() throws Exception { List mutations = new ArrayList<>(); - int propertySize = 900_000; + int entitySize = 0; for (int i = 0; i < 12; ++i) { - Entity.Builder entity = Entity.newBuilder().setKey(makeKey("key" + i, i + 1)); - entity.putProperties("long", makeValue(new String(new char[propertySize]) - ).setExcludeFromIndexes(true).build()); - mutations.add(makeUpsert(entity.build()).build()); + Entity entity = Entity.newBuilder().setKey(makeKey("key" + i, i + 1)) + .putProperties("long", makeValue(new String(new char[900_000]) + ).setExcludeFromIndexes(true).build()) + .build(); + entitySize = entity.getSerializedSize(); // Take the size of any one entity. + mutations.add(makeUpsert(entity).build()); } DatastoreWriterFn datastoreWriter = new DatastoreWriterFn(StaticValueProvider.of(PROJECT_ID), @@ -667,10 +669,10 @@ public void testDatatoreWriterFnWithLargeEntities() throws Exception { // This test is over-specific currently; it requires that we split the 12 entity writes into 3 // requests, but we only need each CommitRequest to be less than 10MB in size. 
- int propertiesPerRpc = DATASTORE_BATCH_UPDATE_BYTES_LIMIT / propertySize; + int entitiesPerRpc = DATASTORE_BATCH_UPDATE_BYTES_LIMIT / entitySize; int start = 0; while (start < mutations.size()) { - int end = Math.min(mutations.size(), start + propertiesPerRpc); + int end = Math.min(mutations.size(), start + entitiesPerRpc); CommitRequest.Builder commitRequest = CommitRequest.newBuilder(); commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); commitRequest.addAllMutations(mutations.subList(start, end)); From ca7b9c288151d318898ab000b91d26fcf62046ca Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 25 May 2017 06:29:09 -0700 Subject: [PATCH 008/200] Add Runner API oriented PTransformMatchers for DirectRunner overrides --- .../core/construction/PTransformMatchers.java | 94 ++++++++++++++++++- .../construction/PTransformTranslation.java | 7 +- .../construction/PTransformMatchersTest.java | 32 +++++++ 3 files changed, 128 insertions(+), 5 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java index bfe24a02ab633..c339891d51eda 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.core.construction; import com.google.common.base.MoreObjects; +import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.apache.beam.sdk.annotations.Experimental; @@ -49,6 +50,34 @@ public class PTransformMatchers { private PTransformMatchers() {} + /** + * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the URN of the + * {@link PTransform} is equal to the URN provided ot this matcher. + */ + public static PTransformMatcher urnEqualTo(String urn) { + return new EqualUrnPTransformMatcher(urn); + } + + private static class EqualUrnPTransformMatcher implements PTransformMatcher { + private final String urn; + + private EqualUrnPTransformMatcher(String urn) { + this.urn = urn; + } + + @Override + public boolean matches(AppliedPTransform application) { + return urn.equals(PTransformTranslation.urnForTransformOrNull(application.getTransform())); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("urn", urn) + .toString(); + } + } + /** * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the * {@link PTransform} is equal to the {@link Class} provided ot this matcher. @@ -150,6 +179,68 @@ public String toString() { }; } + /** + * A {@link PTransformMatcher} that matches a {@link ParDo} by URN if it has a splittable {@link + * DoFn}. 
+ */ + public static PTransformMatcher splittableParDo() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform application) { + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { + + try { + return ParDoTranslation.isSplittable(application); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Transform with URN %s could not be translated", + PTransformTranslation.PAR_DO_TRANSFORM_URN), + e); + } + } + return false; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("SplittableParDoMultiMatcher").toString(); + } + }; + } + + /** + * A {@link PTransformMatcher} that matches a {@link ParDo} transform by URN + * and whether it contains state or timers as specified by {@link ParDoTranslation}. + */ + public static PTransformMatcher stateOrTimerParDo() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform application) { + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { + + try { + return ParDoTranslation.usesStateOrTimers(application); + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Transform with URN %s could not be translated", + PTransformTranslation.PAR_DO_TRANSFORM_URN), + e); + } + } + return false; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("StateOrTimerParDoMatcher").toString(); + } + }; + } + /** * A {@link PTransformMatcher} that matches a {@link ParDo.MultiOutput} containing a {@link DoFn} * that uses state or timers, as specified by {@link DoFnSignature#usesState()} and @@ -268,7 +359,8 @@ public static PTransformMatcher writeWithRunnerDeterminedSharding() { return new PTransformMatcher() { @Override public boolean matches(AppliedPTransform application) { - if (application.getTransform() instanceof WriteFiles) { + if (PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { WriteFiles write = (WriteFiles) application.getTransform(); return write.getSharding() == null && write.getNumShards() == null; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 32ecf430c2718..bae7b0574b22a 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -179,13 +179,12 @@ public static String urnForTransformOrNull(PTransform transform) { * Returns the URN for the transform if it is known, otherwise throws. 
*/ public static String urnForTransform(PTransform transform) { - TransformPayloadTranslator translator = KNOWN_PAYLOAD_TRANSLATORS.get(transform.getClass()); - if (translator == null) { + String urn = urnForTransformOrNull(transform); + if (urn == null) { throw new IllegalStateException( String.format("No translator known for %s", transform.getClass().getName())); } - - return translator.getUrn(transform); + return urn; } /** diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java index 249759880803d..6459849f24fa9 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java @@ -27,6 +27,8 @@ import com.google.common.collect.ImmutableMap; import java.io.Serializable; import java.util.Collections; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.DefaultFilenamePolicy; @@ -95,9 +97,14 @@ public class PTransformMatchersTest implements Serializable { PCollection> input = PCollection.createPrimitiveOutputInternal( p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED); + input.setName("dummy input"); + input.setCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); + PCollection output = PCollection.createPrimitiveOutputInternal( p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED); + output.setName("dummy output"); + output.setCoder(VarIntCoder.of()); return AppliedPTransform.of("pardo", input.expand(), output.expand(), pardo, p); } @@ -271,6 +278,18 @@ public void parDoMultiSplittable() { assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); } + @Test + public void parDoSplittable() { + AppliedPTransform parDoApplication = + getAppliedTransform( + ParDo.of(splittableDoFn).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.splittableParDo().matches(parDoApplication), is(true)); + + assertThat(PTransformMatchers.stateOrTimerParDoMulti().matches(parDoApplication), is(false)); + assertThat(PTransformMatchers.splittableParDoSingle().matches(parDoApplication), is(false)); + assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); + } + @Test public void parDoMultiWithState() { AppliedPTransform parDoApplication = @@ -283,6 +302,19 @@ public void parDoMultiWithState() { assertThat(PTransformMatchers.stateOrTimerParDoSingle().matches(parDoApplication), is(false)); } + @Test + public void parDoWithState() { + AppliedPTransform statefulApplication = + getAppliedTransform( + ParDo.of(doFnWithState).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.stateOrTimerParDo().matches(statefulApplication), is(true)); + + AppliedPTransform splittableApplication = + getAppliedTransform( + ParDo.of(splittableDoFn).withOutputTags(new TupleTag(), TupleTagList.empty())); + assertThat(PTransformMatchers.stateOrTimerParDo().matches(splittableApplication), is(false)); + } + @Test public void parDoMultiWithTimers() { AppliedPTransform parDoApplication = From d8d9087877c01f1786271726a541fb3eeda7f939 Mon Sep 17 00:00:00 2001 
From: Kenneth Knowles Date: Thu, 25 May 2017 06:31:16 -0700 Subject: [PATCH 009/200] DirectRunner override matchers using Runner API --- .../beam/runners/direct/DirectRunner.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index dbd1ec47ed541..136ccf3bd86a5 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -30,6 +30,7 @@ import java.util.Set; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformMatchers; +import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; @@ -42,12 +43,9 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; -import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; -import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; @@ -230,33 +228,33 @@ List defaultTransformOverrides() { new WriteWithShardingFactory())) /* Uses a view internally. 
*/ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(CreatePCollectionView.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), new ViewOverrideFactory())) /* Uses pardos and GBKs */ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(TestStream.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), new DirectTestStreamFactory(this))) /* primitive */ // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra // primitives .add( PTransformOverride.of( - PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory())) + PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())) // state and timer pardos are implemented in terms of simple ParDos and extra primitives .add( PTransformOverride.of( - PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory())) + PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class), + PTransformMatchers.urnEqualTo( + SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN), new SplittableParDoViaKeyedWorkItems.OverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.classEqualTo( - SplittableParDoViaKeyedWorkItems.GBKIntoKeyedWorkItems.class), + PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN), new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ .add( PTransformOverride.of( - PTransformMatchers.classEqualTo(GroupByKey.class), + PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), new DirectGroupByKeyOverrideFactory())) /* returns two chained primitives. */ .build(); } From 36aea2d26d7c8ea3299d7a25d617a6ba99794e18 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 7 Jun 2017 23:35:11 -0700 Subject: [PATCH 010/200] Use inner module for non-public coders. 
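
The worker-side spec decoding now reaches coder classes through the implementation module apache_beam.coders.coders instead of the apache_beam.coders package facade, since classes such as WindowedValueCoder and IntervalWindowCoder are not part of the package's public surface. A minimal sketch of the pattern follows; the helper name and the two handled kinds are illustrative only and are not part of this patch (the real dispatch lives in get_coder_from_spec).

# Illustrative sketch, not part of this patch: map a cloud coder "@type"
# kind to a coder instance, using the public package facade for public
# coders and the inner module for non-public ones.
from apache_beam import coders

def _coder_for_kind(kind):
    # Hypothetical helper mirroring the dispatch in operation_specs.py.
    if kind == 'kind:bytes':
        # BytesCoder is exported by the public coders package.
        return coders.BytesCoder()
    elif kind == 'kind:interval_window':
        # IntervalWindowCoder is non-public; reach it via the inner module.
        return coders.coders.IntervalWindowCoder()
    else:
        raise ValueError('Unsupported coder kind: %r' % kind)
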
--- sdks/python/apache_beam/runners/worker/operation_specs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py index b8d19a1427563..bdafbeaf44ad6 100644 --- a/sdks/python/apache_beam/runners/worker/operation_specs.py +++ b/sdks/python/apache_beam/runners/worker/operation_specs.py @@ -326,11 +326,12 @@ def get_coder_from_spec(coder_spec): assert len(coder_spec['component_encodings']) == 2 value_coder, window_coder = [ get_coder_from_spec(c) for c in coder_spec['component_encodings']] - return coders.WindowedValueCoder(value_coder, window_coder=window_coder) + return coders.coders.WindowedValueCoder( + value_coder, window_coder=window_coder) elif coder_spec['@type'] == 'kind:interval_window': assert ('component_encodings' not in coder_spec or not coder_spec['component_encodings']) - return coders.IntervalWindowCoder() + return coders.coders.IntervalWindowCoder() elif coder_spec['@type'] == 'kind:global_window': assert ('component_encodings' not in coder_spec or not coder_spec['component_encodings']) From 349898c4702fc3e52d8c0cd1c5a04f14cd40fd27 Mon Sep 17 00:00:00 2001 From: Seshadri Chakkravarthy Date: Thu, 18 May 2017 12:07:01 -0700 Subject: [PATCH 011/200] Implements HCatalogIO --- sdks/java/io/hcatalog/pom.xml | 163 ++++++ .../beam/sdk/io/hcatalog/HCatalogIO.java | 511 ++++++++++++++++++ .../beam/sdk/io/hcatalog/package-info.java | 22 + .../io/hcatalog/EmbeddedMetastoreService.java | 88 +++ .../beam/sdk/io/hcatalog/HCatalogIOTest.java | 271 ++++++++++ .../sdk/io/hcatalog/HCatalogIOTestUtils.java | 106 ++++ .../hcatalog/src/test/resources/hive-site.xml | 301 +++++++++++ sdks/java/io/pom.xml | 1 + 8 files changed, 1463 insertions(+) create mode 100644 sdks/java/io/hcatalog/pom.xml create mode 100644 sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java create mode 100644 sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/package-info.java create mode 100644 sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java create mode 100644 sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java create mode 100644 sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java create mode 100644 sdks/java/io/hcatalog/src/test/resources/hive-site.xml diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml new file mode 100644 index 0000000000000..19b62a5826285 --- /dev/null +++ b/sdks/java/io/hcatalog/pom.xml @@ -0,0 +1,163 @@ + + + + + 4.0.0 + + + org.apache.beam + beam-sdks-java-io-parent + 2.1.0-SNAPSHOT + ../pom.xml + + + beam-sdks-java-io-hcatalog + Apache Beam :: SDKs :: Java :: IO :: HCatalog + IO to read and write for HCatalog source. 
+ + + 2.1.0 + 2.5 + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + + + + org.apache.beam + beam-sdks-java-core + + + + org.apache.beam + beam-sdks-java-io-hadoop-common + + + + org.apache.hadoop + hadoop-common + + + + commons-io + commons-io + ${apache.commons.version} + + + + org.slf4j + slf4j-api + + + + com.google.guava + guava + + + + com.google.code.findbugs + jsr305 + + + + org.apache.hive + hive-exec + ${hive.version} + + + + com.google.auto.value + auto-value + provided + + + + org.apache.hive.hcatalog + hive-hcatalog-core + ${hive.version} + + + org.apache.hive + hive-exec + + + com.google.protobuf + protobuf-java + + + + + + org.apache.hive.hcatalog + hive-hcatalog-core + tests + test + ${hive.version} + + + + junit + junit + test + + + + org.apache.hive + hive-exec + ${hive.version} + test-jar + test + + + + org.apache.hive + hive-common + ${hive.version} + test + + + + org.apache.hive + hive-cli + ${hive.version} + test + + + + org.apache.beam + beam-runners-direct-java + test + + + + org.hamcrest + hamcrest-all + test + + + \ No newline at end of file diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java new file mode 100644 index 0000000000000..07b56e3f650cd --- /dev/null +++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.hcatalog; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.hadoop.WritableCoder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.ql.metadata.Table; +import org.apache.hadoop.hive.ql.stats.StatsUtils; +import org.apache.hive.hcatalog.common.HCatConstants; +import org.apache.hive.hcatalog.common.HCatException; +import org.apache.hive.hcatalog.common.HCatUtil; +import org.apache.hive.hcatalog.data.DefaultHCatRecord; +import org.apache.hive.hcatalog.data.HCatRecord; +import org.apache.hive.hcatalog.data.transfer.DataTransferFactory; +import org.apache.hive.hcatalog.data.transfer.HCatReader; +import org.apache.hive.hcatalog.data.transfer.HCatWriter; +import org.apache.hive.hcatalog.data.transfer.ReadEntity; +import org.apache.hive.hcatalog.data.transfer.ReaderContext; +import org.apache.hive.hcatalog.data.transfer.WriteEntity; +import org.apache.hive.hcatalog.data.transfer.WriterContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * IO to read and write data using HCatalog. + * + *

<h3>Reading using HCatalog</h3>
+ *
+ * <p>HCatalog source supports reading of HCatRecord from a HCatalog managed source, e.g. Hive.
+ *
+ * <p>To configure a HCatalog source, you must specify a metastore URI and a table name. Other
+ * optional parameters are database and filter. For instance:
+ *
+ * <pre>{@code
+ * Map<String, String> configProperties = new HashMap<String, String>();
+ * configProperties.put("hive.metastore.uris","thrift://metastore-host:port");
+ *
+ * pipeline
+ *   .apply(HCatalogIO.read()
+ *       .withConfigProperties(configProperties) //mandatory
+ *       .withTable("employee") //mandatory
+ *       .withDatabase("default") //optional, assumes default if none specified
+ *       .withFilter(filterString) //optional, should be specified if the table is partitioned
+ * }</pre>
+ *
+ * <h3>Writing using HCatalog</h3>
+ *
+ * <p>HCatalog sink supports writing of HCatRecord to a HCatalog managed source, e.g. Hive.
+ *
+ * <p>To configure a HCatalog sink, you must specify a metastore URI and a table name. Other
+ * optional parameters are database, partition and batch size. The destination table should exist
+ * beforehand; the transform does not create a new table if it does not exist. For instance:
+ *
+ * <pre>{@code
+ * Map<String, String> configProperties = new HashMap<String, String>();
+ * configProperties.put("hive.metastore.uris","thrift://metastore-host:port");
+ *
+ * pipeline
+ *   .apply(...)
+ *   .apply(HCatalogIO.write()
+ *       .withConfigProperties(configProperties) //mandatory
+ *       .withTable("employee") //mandatory
+ *       .withDatabase("default") //optional, assumes default if none specified
+ *       .withFilter(partitionValues) //optional, should be specified if the table is partitioned
+ *       .withBatchSize(1024L)) //optional, assumes a default batch size of 1024 if none specified
+ * }</pre>
+ */ +@Experimental +public class HCatalogIO { + + private static final Logger LOG = LoggerFactory.getLogger(HCatalogIO.class); + + /** Write data to Hive. */ + public static Write write() { + return new AutoValue_HCatalogIO_Write.Builder().setBatchSize(1024L).build(); + } + + /** Read data from Hive. */ + public static Read read() { + return new AutoValue_HCatalogIO_Read.Builder().setDatabase("default").build(); + } + + private HCatalogIO() {} + + /** A {@link PTransform} to read data using HCatalog. */ + @VisibleForTesting + @AutoValue + public abstract static class Read extends PTransform> { + @Nullable + abstract Map getConfigProperties(); + + @Nullable + abstract String getDatabase(); + + @Nullable + abstract String getTable(); + + @Nullable + abstract String getFilter(); + + @Nullable + abstract ReaderContext getContext(); + + @Nullable + abstract Integer getSplitId(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setConfigProperties(Map configProperties); + + abstract Builder setDatabase(String database); + + abstract Builder setTable(String table); + + abstract Builder setFilter(String filter); + + abstract Builder setSplitId(Integer splitId); + + abstract Builder setContext(ReaderContext context); + + abstract Read build(); + } + + /** Sets the configuration properties like metastore URI. This is mandatory */ + public Read withConfigProperties(Map configProperties) { + return toBuilder().setConfigProperties(new HashMap<>(configProperties)).build(); + } + + /** Sets the database name. This is optional, assumes 'default' database if none specified */ + public Read withDatabase(String database) { + return toBuilder().setDatabase(database).build(); + } + + /** Sets the table name to read from. This is mandatory */ + public Read withTable(String table) { + return toBuilder().setTable(table).build(); + } + + /** Sets the filter (partition) details. This is optional, assumes none if not specified */ + public Read withFilter(String filter) { + return toBuilder().setFilter(filter).build(); + } + + Read withSplitId(int splitId) { + checkArgument(splitId >= 0, "Invalid split id-" + splitId); + return toBuilder().setSplitId(splitId).build(); + } + + Read withContext(ReaderContext context) { + return toBuilder().setContext(context).build(); + } + + @Override + public PCollection expand(PBegin input) { + return input.apply(org.apache.beam.sdk.io.Read.from(new BoundedHCatalogSource(this))); + } + + @Override + public void validate(PipelineOptions options) { + checkNotNull(getTable(), "table"); + checkNotNull(getConfigProperties(), "configProperties"); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData.item("configProperties", getConfigProperties().toString())); + builder.add(DisplayData.item("table", getTable())); + builder.addIfNotNull(DisplayData.item("database", getDatabase())); + builder.addIfNotNull(DisplayData.item("filter", getFilter())); + } + } + + /** A HCatalog {@link BoundedSource} reading {@link HCatRecord} from a given instance. 
*/ + @VisibleForTesting + static class BoundedHCatalogSource extends BoundedSource { + private Read spec; + + BoundedHCatalogSource(Read spec) { + this.spec = spec; + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public Coder getDefaultOutputCoder() { + return (Coder) WritableCoder.of(DefaultHCatRecord.class); + } + + @Override + public void validate() { + spec.validate(null); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + spec.populateDisplayData(builder); + } + + @Override + public BoundedReader createReader(PipelineOptions options) { + return new BoundedHCatalogReader(this); + } + + /** + * Returns the size of the table in bytes, does not take into consideration filter/partition + * details passed, if any. + */ + @Override + public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) throws Exception { + Configuration conf = new Configuration(); + for (Entry entry : spec.getConfigProperties().entrySet()) { + conf.set(entry.getKey(), entry.getValue()); + } + IMetaStoreClient client = null; + try { + HiveConf hiveConf = HCatUtil.getHiveConf(conf); + client = HCatUtil.getHiveMetastoreClient(hiveConf); + Table table = HCatUtil.getTable(client, spec.getDatabase(), spec.getTable()); + return StatsUtils.getFileSizeForTable(hiveConf, table); + } finally { + // IMetaStoreClient is not AutoCloseable, closing it manually + if (client != null) { + client.close(); + } + } + } + + /** + * Calculates the 'desired' number of splits based on desiredBundleSizeBytes which is passed as + * a hint to native API. Retrieves the actual splits generated by native API, which could be + * different from the 'desired' split count calculated using desiredBundleSizeBytes + */ + @Override + public List> split( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + int desiredSplitCount = 1; + long estimatedSizeBytes = getEstimatedSizeBytes(options); + if (desiredBundleSizeBytes > 0 && estimatedSizeBytes > 0) { + desiredSplitCount = (int) Math.ceil((double) estimatedSizeBytes / desiredBundleSizeBytes); + } + ReaderContext readerContext = getReaderContext(desiredSplitCount); + //process the splits returned by native API + //this could be different from 'desiredSplitCount' calculated above + LOG.info( + "Splitting into bundles of {} bytes: " + + "estimated size {}, desired split count {}, actual split count {}", + desiredBundleSizeBytes, + estimatedSizeBytes, + desiredSplitCount, + readerContext.numSplits()); + + List> res = new ArrayList<>(); + for (int split = 0; split < readerContext.numSplits(); split++) { + res.add(new BoundedHCatalogSource(spec.withContext(readerContext).withSplitId(split))); + } + return res; + } + + private ReaderContext getReaderContext(long desiredSplitCount) throws HCatException { + ReadEntity entity = + new ReadEntity.Builder() + .withDatabase(spec.getDatabase()) + .withTable(spec.getTable()) + .withFilter(spec.getFilter()) + .build(); + // pass the 'desired' split count as an hint to the API + Map configProps = new HashMap<>(spec.getConfigProperties()); + configProps.put( + HCatConstants.HCAT_DESIRED_PARTITION_NUM_SPLITS, String.valueOf(desiredSplitCount)); + return DataTransferFactory.getHCatReader(entity, configProps).prepareRead(); + } + + static class BoundedHCatalogReader extends BoundedSource.BoundedReader { + private final BoundedHCatalogSource source; + private HCatRecord current; + private Iterator hcatIterator; + + public BoundedHCatalogReader(BoundedHCatalogSource source) { + this.source = 
source; + } + + @Override + public boolean start() throws HCatException { + HCatReader reader = + DataTransferFactory.getHCatReader(source.spec.getContext(), source.spec.getSplitId()); + hcatIterator = reader.read(); + return advance(); + } + + @Override + public boolean advance() { + if (hcatIterator.hasNext()) { + current = hcatIterator.next(); + return true; + } else { + current = null; + return false; + } + } + + @Override + public BoundedHCatalogSource getCurrentSource() { + return source; + } + + @Override + public HCatRecord getCurrent() { + if (current == null) { + throw new NoSuchElementException("Current element is null"); + } + return current; + } + + @Override + public void close() { + // nothing to close/release + } + } + } + + /** A {@link PTransform} to write to a HCatalog managed source. */ + @AutoValue + public abstract static class Write extends PTransform, PDone> { + @Nullable + abstract Map getConfigProperties(); + + @Nullable + abstract String getDatabase(); + + @Nullable + abstract String getTable(); + + @Nullable + abstract Map getFilter(); + + abstract long getBatchSize(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setConfigProperties(Map configProperties); + + abstract Builder setDatabase(String database); + + abstract Builder setTable(String table); + + abstract Builder setFilter(Map partition); + + abstract Builder setBatchSize(long batchSize); + + abstract Write build(); + } + + /** Sets the configuration properties like metastore URI. This is mandatory */ + public Write withConfigProperties(Map configProperties) { + return toBuilder().setConfigProperties(new HashMap<>(configProperties)).build(); + } + + /** Sets the database name. This is optional, assumes 'default' database if none specified */ + public Write withDatabase(String database) { + return toBuilder().setDatabase(database).build(); + } + + /** Sets the table name to write to, the table should exist beforehand. This is mandatory */ + public Write withTable(String table) { + return toBuilder().setTable(table).build(); + } + + /** Sets the filter (partition) details. This is required if the table is partitioned */ + public Write withFilter(Map filter) { + return toBuilder().setFilter(filter).build(); + } + + /** + * Sets batch size for the write operation. 
This is optional, assumes a default batch size of + * 1024 if not set + */ + public Write withBatchSize(long batchSize) { + return toBuilder().setBatchSize(batchSize).build(); + } + + @Override + public PDone expand(PCollection input) { + input.apply(ParDo.of(new WriteFn(this))); + return PDone.in(input.getPipeline()); + } + + @Override + public void validate(PipelineOptions options) { + checkNotNull(getConfigProperties(), "configProperties"); + checkNotNull(getTable(), "table"); + } + + private static class WriteFn extends DoFn { + private final Write spec; + private WriterContext writerContext; + private HCatWriter slaveWriter; + private HCatWriter masterWriter; + private List hCatRecordsBatch; + + public WriteFn(Write spec) { + this.spec = spec; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.addIfNotNull(DisplayData.item("database", spec.getDatabase())); + builder.add(DisplayData.item("table", spec.getTable())); + builder.addIfNotNull(DisplayData.item("filter", String.valueOf(spec.getFilter()))); + builder.add(DisplayData.item("configProperties", spec.getConfigProperties().toString())); + builder.add(DisplayData.item("batchSize", spec.getBatchSize())); + } + + @Setup + public void initiateWrite() throws HCatException { + WriteEntity entity = + new WriteEntity.Builder() + .withDatabase(spec.getDatabase()) + .withTable(spec.getTable()) + .withPartition(spec.getFilter()) + .build(); + masterWriter = DataTransferFactory.getHCatWriter(entity, spec.getConfigProperties()); + writerContext = masterWriter.prepareWrite(); + slaveWriter = DataTransferFactory.getHCatWriter(writerContext); + } + + @StartBundle + public void startBundle() { + hCatRecordsBatch = new ArrayList<>(); + } + + @ProcessElement + public void processElement(ProcessContext ctx) throws HCatException { + hCatRecordsBatch.add(ctx.element()); + if (hCatRecordsBatch.size() >= spec.getBatchSize()) { + flush(); + } + } + + @FinishBundle + public void finishBundle() throws HCatException { + flush(); + } + + private void flush() throws HCatException { + if (hCatRecordsBatch.isEmpty()) { + return; + } + try { + slaveWriter.write(hCatRecordsBatch.iterator()); + masterWriter.commit(writerContext); + } catch (HCatException e) { + LOG.error("Exception in flush - write/commit data to Hive", e); + //abort on exception + masterWriter.abort(writerContext); + throw e; + } finally { + hCatRecordsBatch.clear(); + } + } + } + } +} diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/package-info.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/package-info.java new file mode 100644 index 0000000000000..dff5bd120e6c9 --- /dev/null +++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Transforms for reading and writing using HCatalog. + */ +package org.apache.beam.sdk.io.hcatalog; diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java new file mode 100644 index 0000000000000..5792bf6f810cd --- /dev/null +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.hcatalog; + +import static org.apache.hive.hcatalog.common.HCatUtil.makePathASafeFileName; + +import java.io.File; +import java.io.IOException; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.cli.CliSessionState; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.CommandNeedRetryException; +import org.apache.hadoop.hive.ql.Driver; +import org.apache.hadoop.hive.ql.session.SessionState; + +/** + * Implementation of a light-weight embedded metastore. 
This class is a trimmed-down version of + * https://github.com/apache/hive/blob/master/hcatalog/core/src/test/java/org/apache/hive/hcatalog/mapreduce + * /HCatBaseTest.java + */ +public final class EmbeddedMetastoreService implements AutoCloseable { + private final Driver driver; + private final HiveConf hiveConf; + private final SessionState sessionState; + + EmbeddedMetastoreService(String baseDirPath) throws IOException { + FileUtils.forceDeleteOnExit(new File(baseDirPath)); + + String hiveDirPath = makePathASafeFileName(baseDirPath + "/hive"); + String testDataDirPath = + makePathASafeFileName( + hiveDirPath + + "/data/" + + EmbeddedMetastoreService.class.getCanonicalName() + + System.currentTimeMillis()); + String testWarehouseDirPath = makePathASafeFileName(testDataDirPath + "/warehouse"); + + hiveConf = new HiveConf(getClass()); + hiveConf.setVar(HiveConf.ConfVars.PREEXECHOOKS, ""); + hiveConf.setVar(HiveConf.ConfVars.POSTEXECHOOKS, ""); + hiveConf.setBoolVar(HiveConf.ConfVars.HIVE_SUPPORT_CONCURRENCY, false); + hiveConf.setVar(HiveConf.ConfVars.METASTOREWAREHOUSE, testWarehouseDirPath); + hiveConf.setVar(HiveConf.ConfVars.HIVEMAPREDMODE, "nonstrict"); + hiveConf.setBoolVar(HiveConf.ConfVars.HIVEOPTIMIZEMETADATAQUERIES, true); + hiveConf.setVar( + HiveConf.ConfVars.HIVE_AUTHORIZATION_MANAGER, + "org.apache.hadoop.hive.ql.security.authorization.plugin.sqlstd." + + "SQLStdHiveAuthorizerFactory"); + hiveConf.set("test.tmp.dir", hiveDirPath); + + System.setProperty("derby.stream.error.file", "/dev/null"); + driver = new Driver(hiveConf); + sessionState = SessionState.start(new CliSessionState(hiveConf)); + } + + /** Executes the passed query on the embedded metastore service. */ + void executeQuery(String query) throws CommandNeedRetryException { + driver.run(query); + } + + /** Returns the HiveConf object for the embedded metastore. */ + HiveConf getHiveConf() { + return hiveConf; + } + + @Override + public void close() throws Exception { + driver.close(); + sessionState.close(); + } +} diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java new file mode 100644 index 0000000000000..49c538f0eb695 --- /dev/null +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.hcatalog; + +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_RECORDS_COUNT; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_TABLE_NAME; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getConfigPropertiesAsMap; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getExpectedRecords; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getHCatRecords; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getReaderContext; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.insertTestData; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.hcatalog.HCatalogIO.BoundedHCatalogSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.SourceTestUtils; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.UserCodeException; +import org.apache.beam.sdk.values.PCollection; +import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; +import org.apache.hadoop.hive.ql.CommandNeedRetryException; +import org.apache.hive.hcatalog.data.HCatRecord; +import org.apache.hive.hcatalog.data.transfer.ReaderContext; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +/** Test for HCatalogIO. 
*/ +public class HCatalogIOTest implements Serializable { + public static final PipelineOptions OPTIONS = PipelineOptionsFactory.create(); + + @ClassRule + public static final TemporaryFolder TMP_FOLDER = new TemporaryFolder(); + + @Rule public final transient TestPipeline defaultPipeline = TestPipeline.create(); + + @Rule public final transient TestPipeline readAfterWritePipeline = TestPipeline.create(); + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + @Rule + public final transient TestRule testDataSetupRule = + new TestWatcher() { + public Statement apply(final Statement base, final Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + if (description.getAnnotation(NeedsTestData.class) != null) { + prepareTestData(); + } else if (description.getAnnotation(NeedsEmptyTestTables.class) != null) { + reCreateTestTable(); + } + base.evaluate(); + } + }; + } + }; + + private static EmbeddedMetastoreService service; + + /** Use this annotation to setup complete test data(table populated with records). */ + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.METHOD}) + @interface NeedsTestData {} + + /** Use this annotation to setup test tables alone(empty tables, no records are populated). */ + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.METHOD}) + @interface NeedsEmptyTestTables {} + + @BeforeClass + public static void setupEmbeddedMetastoreService () throws IOException { + service = new EmbeddedMetastoreService(TMP_FOLDER.getRoot().getAbsolutePath()); + } + + @AfterClass + public static void shutdownEmbeddedMetastoreService () throws Exception { + service.executeQuery("drop table " + TEST_TABLE_NAME); + service.close(); + } + + /** Perform end-to-end test of Write-then-Read operation. */ + @Test + @NeedsEmptyTestTables + public void testWriteThenReadSuccess() throws Exception { + defaultPipeline + .apply(Create.of(getHCatRecords(TEST_RECORDS_COUNT))) + .apply( + HCatalogIO.write() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withTable(TEST_TABLE_NAME)); + defaultPipeline.run(); + + PCollection output = + readAfterWritePipeline + .apply( + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withTable(HCatalogIOTestUtils.TEST_TABLE_NAME)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().get(0).toString()); + } + })); + PAssert.that(output).containsInAnyOrder(getExpectedRecords(TEST_RECORDS_COUNT)); + readAfterWritePipeline.run(); + } + + /** Test of Write to a non-existent table. */ + @Test + public void testWriteFailureTableDoesNotExist() throws Exception { + thrown.expectCause(isA(UserCodeException.class)); + thrown.expectMessage(containsString("org.apache.hive.hcatalog.common.HCatException")); + thrown.expectMessage(containsString("NoSuchObjectException")); + defaultPipeline + .apply(Create.of(getHCatRecords(TEST_RECORDS_COUNT))) + .apply( + HCatalogIO.write() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withTable("myowntable")); + defaultPipeline.run(); + } + + /** Test of Write without specifying a table. 
*/ + @Test + public void testWriteFailureValidationTable() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage(containsString("table")); + HCatalogIO.write() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .validate(null); + } + + /** Test of Write without specifying configuration properties. */ + @Test + public void testWriteFailureValidationConfigProp() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage(containsString("configProperties")); + HCatalogIO.write().withTable("myowntable").validate(null); + } + + /** Test of Read from a non-existent table. */ + @Test + public void testReadFailureTableDoesNotExist() throws Exception { + defaultPipeline.apply( + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withTable("myowntable")); + thrown.expectCause(isA(NoSuchObjectException.class)); + defaultPipeline.run(); + } + + /** Test of Read without specifying configuration properties. */ + @Test + public void testReadFailureValidationConfig() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage(containsString("configProperties")); + HCatalogIO.read().withTable("myowntable").validate(null); + } + + /** Test of Read without specifying a table. */ + @Test + public void testReadFailureValidationTable() throws Exception { + thrown.expect(NullPointerException.class); + thrown.expectMessage(containsString("table")); + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .validate(null); + } + + /** Test of Read using SourceTestUtils.readFromSource(..). */ + @Test + @NeedsTestData + public void testReadFromSource() throws Exception { + ReaderContext context = getReaderContext(getConfigPropertiesAsMap(service.getHiveConf())); + HCatalogIO.Read spec = + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withContext(context) + .withTable(TEST_TABLE_NAME); + + List records = new ArrayList<>(); + for (int i = 0; i < context.numSplits(); i++) { + BoundedHCatalogSource source = new BoundedHCatalogSource(spec.withSplitId(i)); + for (HCatRecord record : SourceTestUtils.readFromSource(source, OPTIONS)) { + records.add(record.get(0).toString()); + } + } + assertThat(records, containsInAnyOrder(getExpectedRecords(TEST_RECORDS_COUNT).toArray())); + } + + /** Test of Read using SourceTestUtils.assertSourcesEqualReferenceSource(..). 
*/ + @Test + @NeedsTestData + public void testSourceEqualsSplits() throws Exception { + final int numRows = 1500; + final int numSamples = 10; + final long bytesPerRow = 15; + ReaderContext context = getReaderContext(getConfigPropertiesAsMap(service.getHiveConf())); + HCatalogIO.Read spec = + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withContext(context) + .withTable(TEST_TABLE_NAME); + + BoundedHCatalogSource source = new BoundedHCatalogSource(spec); + List> unSplitSource = source.split(-1, OPTIONS); + assertEquals(1, unSplitSource.size()); + + List> splits = + source.split(numRows * bytesPerRow / numSamples, OPTIONS); + assertTrue(splits.size() >= 1); + + SourceTestUtils.assertSourcesEqualReferenceSource(unSplitSource.get(0), splits, OPTIONS); + } + + private void reCreateTestTable() throws CommandNeedRetryException { + service.executeQuery("drop table " + TEST_TABLE_NAME); + service.executeQuery("create table " + TEST_TABLE_NAME + "(mycol1 string, mycol2 int)"); + } + + private void prepareTestData() throws Exception { + reCreateTestTable(); + insertTestData(getConfigPropertiesAsMap(service.getHiveConf())); + } +} diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java new file mode 100644 index 0000000000000..f66e0bcc1e8fa --- /dev/null +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.hcatalog; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.hcatalog.common.HCatException; +import org.apache.hive.hcatalog.data.DefaultHCatRecord; +import org.apache.hive.hcatalog.data.HCatRecord; +import org.apache.hive.hcatalog.data.transfer.DataTransferFactory; +import org.apache.hive.hcatalog.data.transfer.ReadEntity; +import org.apache.hive.hcatalog.data.transfer.ReaderContext; +import org.apache.hive.hcatalog.data.transfer.WriteEntity; +import org.apache.hive.hcatalog.data.transfer.WriterContext; + +/** Utility class for HCatalogIOTest. 
*/ +public class HCatalogIOTestUtils { + public static final String TEST_TABLE_NAME = "mytable"; + + public static final int TEST_RECORDS_COUNT = 1000; + + private static final ReadEntity READ_ENTITY = + new ReadEntity.Builder().withTable(TEST_TABLE_NAME).build(); + private static final WriteEntity WRITE_ENTITY = + new WriteEntity.Builder().withTable(TEST_TABLE_NAME).build(); + + /** Returns a ReaderContext instance for the passed datastore config params. */ + static ReaderContext getReaderContext(Map config) throws HCatException { + return DataTransferFactory.getHCatReader(READ_ENTITY, config).prepareRead(); + } + + /** Returns a WriterContext instance for the passed datastore config params. */ + static WriterContext getWriterContext(Map config) throws HCatException { + return DataTransferFactory.getHCatWriter(WRITE_ENTITY, config).prepareWrite(); + } + + /** Writes records to the table using the passed WriterContext. */ + static void writeRecords(WriterContext context) throws HCatException { + DataTransferFactory.getHCatWriter(context).write(getHCatRecords(TEST_RECORDS_COUNT).iterator()); + } + + /** Commits the pending writes to the database. */ + static void commitRecords(Map config, WriterContext context) throws IOException { + DataTransferFactory.getHCatWriter(WRITE_ENTITY, config).commit(context); + } + + /** Returns a list of strings containing 'expected' test data for verification. */ + static List getExpectedRecords(int count) { + List expected = new ArrayList<>(); + for (int i = 0; i < count; i++) { + expected.add("record " + i); + } + return expected; + } + + /** Returns a list of HCatRecords of passed size. */ + static List getHCatRecords(int size) { + List expected = new ArrayList<>(); + for (int i = 0; i < size; i++) { + expected.add(toHCatRecord(i)); + } + return expected; + } + + /** Inserts data into test datastore. */ + static void insertTestData(Map configMap) throws Exception { + WriterContext cntxt = getWriterContext(configMap); + writeRecords(cntxt); + commitRecords(configMap, cntxt); + } + + /** Returns config params for the test datastore as a Map. */ + static Map getConfigPropertiesAsMap(HiveConf hiveConf) { + Map map = new HashMap<>(); + for (Entry kv : hiveConf) { + map.put(kv.getKey(), kv.getValue()); + } + return map; + } + + /** returns a DefaultHCatRecord instance for passed value. */ + static DefaultHCatRecord toHCatRecord(int value) { + return new DefaultHCatRecord(Arrays.asList("record " + value, value)); + } +} diff --git a/sdks/java/io/hcatalog/src/test/resources/hive-site.xml b/sdks/java/io/hcatalog/src/test/resources/hive-site.xml new file mode 100644 index 0000000000000..5bb1496c2c901 --- /dev/null +++ b/sdks/java/io/hcatalog/src/test/resources/hive-site.xml @@ -0,0 +1,301 @@ + + + + + + + + hive.in.test + true + Internal marker for test. Used for masking env-dependent values + + + + + + + + + + + hadoop.tmp.dir + ${test.tmp.dir}/hadoop-tmp + A base for other temporary directories. 
+ + + + + + hive.exec.scratchdir + ${test.tmp.dir}/scratchdir + Scratch space for Hive jobs + + + + hive.exec.local.scratchdir + ${test.tmp.dir}/localscratchdir/ + Local scratch space for Hive jobs + + + + datanucleus.schema.autoCreateAll + true + + + + javax.jdo.option.ConnectionURL + jdbc:derby:;databaseName=${test.tmp.dir}/junit_metastore_db;create=true + + + + javax.jdo.option.ConnectionDriverName + org.apache.derby.jdbc.EmbeddedDriver + + + + javax.jdo.option.ConnectionUserName + APP + + + + javax.jdo.option.ConnectionPassword + mine + + + + + hive.metastore.warehouse.dir + ${test.warehouse.dir} + + + + + hive.metastore.metadb.dir + file://${test.tmp.dir}/metadb/ + + Required by metastore server or if the uris argument below is not supplied + + + + + test.log.dir + ${test.tmp.dir}/log/ + + + + + test.data.files + ${hive.root}/data/files + + + + + test.data.scripts + ${hive.root}/data/scripts + + + + + hive.jar.path + ${maven.local.repository}/org/apache/hive/hive-exec/${hive.version}/hive-exec-${hive.version}.jar + + + + + hive.metastore.rawstore.impl + org.apache.hadoop.hive.metastore.ObjectStore + Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database + + + + hive.querylog.location + ${test.tmp.dir}/tmp + Location of the structured hive logs + + + + hive.exec.pre.hooks + org.apache.hadoop.hive.ql.hooks.PreExecutePrinter, org.apache.hadoop.hive.ql.hooks.EnforceReadOnlyTables + Pre Execute Hook for Tests + + + + hive.exec.post.hooks + org.apache.hadoop.hive.ql.hooks.PostExecutePrinter + Post Execute Hook for Tests + + + + hive.support.concurrency + true + Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks. + + + + hive.unlock.numretries + 2 + The number of times you want to retry to do one unlock + + + + hive.lock.sleep.between.retries + 2 + The sleep time (in seconds) between various retries + + + + + fs.pfile.impl + org.apache.hadoop.fs.ProxyLocalFileSystem + A proxy for local file system used for cross file system testing + + + + hive.exec.mode.local.auto + false + + Let hive determine whether to run in local mode automatically + Disabling this for tests so that minimr is not affected + + + + + hive.auto.convert.join + false + Whether Hive enable the optimization about converting common join into mapjoin based on the input file size + + + + hive.ignore.mapjoin.hint + false + Whether Hive ignores the mapjoin hint + + + + hive.input.format + org.apache.hadoop.hive.ql.io.CombineHiveInputFormat + The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat. 
+ + + + hive.default.rcfile.serde + org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe + The default SerDe hive will use for the rcfile format + + + + hive.stats.key.prefix.reserve.length + 0 + + + + hive.conf.restricted.list + dummy.config.value + Using dummy config value above because you cannot override config with empty value + + + + hive.exec.submit.local.task.via.child + false + + + + + hive.dummyparam.test.server.specific.config.override + from.hive-site.xml + Using dummy param to test server specific configuration + + + + hive.dummyparam.test.server.specific.config.hivesite + from.hive-site.xml + Using dummy param to test server specific configuration + + + + test.var.hiveconf.property + ${hive.exec.default.partition.name} + Test hiveconf property substitution + + + + test.property1 + value1 + Test property defined in hive-site.xml only + + + + hive.test.dummystats.aggregator + value2 + + + + hive.fetch.task.conversion + minimal + + + + hive.users.in.admin.role + hive_admin_user + + + + hive.llap.io.cache.orc.size + 8388608 + + + + hive.llap.io.cache.orc.arena.size + 8388608 + + + + hive.llap.io.cache.orc.alloc.max + 2097152 + + + + + hive.llap.io.cache.orc.alloc.min + 32768 + + + + hive.llap.cache.allow.synthetic.fileid + true + + + + hive.llap.io.use.lrfu + true + + + + + hive.llap.io.allocator.direct + false + + + + diff --git a/sdks/java/io/pom.xml b/sdks/java/io/pom.xml index 44f3baa6d5f26..13cd418355daa 100644 --- a/sdks/java/io/pom.xml +++ b/sdks/java/io/pom.xml @@ -72,6 +72,7 @@ hadoop-file-system hadoop hbase + hcatalog jdbc jms kafka From 82a6cb6104c0bb179832dabf4433c183743ea983 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Thu, 8 Jun 2017 14:51:15 -0700 Subject: [PATCH 012/200] Slight debuggability improvements in BigtableIO - ByteKeyRangeTracker.splitAtPosition logs the "insane" case first. - BigtableIO logs the split position at INFO --- .../sdk/io/range/ByteKeyRangeTracker.java | 22 +++++++++++-------- .../beam/sdk/io/gcp/bigtable/BigtableIO.java | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/ByteKeyRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/ByteKeyRangeTracker.java index 99717a4bffa9b..b889ec755fc4c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/ByteKeyRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/ByteKeyRangeTracker.java @@ -71,6 +71,10 @@ public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, ByteKey re "Trying to return record which is before the last-returned record"); if (position == null) { + LOG.info( + "Adjusting range start from {} to {} as position of first returned record", + range.getStartKey(), + recordStart); range = range.withStartKey(recordStart); } position = recordStart; @@ -87,6 +91,15 @@ public synchronized boolean tryReturnRecordAt(boolean isAtSplitPoint, ByteKey re @Override public synchronized boolean trySplitAtPosition(ByteKey splitPosition) { + // Sanity check. + if (!range.containsKey(splitPosition)) { + LOG.warn( + "{}: Rejecting split request at {} because it is not within the range.", + this, + splitPosition); + return false; + } + // Unstarted. if (position == null) { LOG.warn( @@ -106,15 +119,6 @@ public synchronized boolean trySplitAtPosition(ByteKey splitPosition) { return false; } - // Sanity check. 
- if (!range.containsKey(splitPosition)) { - LOG.warn( - "{}: Rejecting split request at {} because it is not within the range.", - this, - splitPosition); - return false; - } - range = range.withEndKey(splitPosition); return true; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index 1692cda65623a..62679bb507d7f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -1027,7 +1027,7 @@ public final synchronized BigtableSource splitAtFraction(double fraction) { "{}: Failed to interpolate key for fraction {}.", rangeTracker.getRange(), fraction, e); return null; } - LOG.debug( + LOG.info( "Proposing to split {} at fraction {} (key {})", rangeTracker, fraction, splitKey); BigtableSource primary; BigtableSource residual; From 7689e43ac4b88c85962cea14d65d788ad27dbe93 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 7 Jun 2017 16:08:43 -0700 Subject: [PATCH 013/200] Move Runner API protos to portability/runners/api This fixes a circular import issue between transforms/ and runners/ --- .gitignore | 2 +- sdks/python/apache_beam/coders/coders.py | 2 +- sdks/python/apache_beam/pipeline.py | 4 ++-- .../python/apache_beam/portability/__init__.py | 18 ++++++++++++++++++ .../portability/runners/__init__.py | 18 ++++++++++++++++++ .../{ => portability}/runners/api/__init__.py | 0 sdks/python/apache_beam/pvalue.py | 2 +- .../runners/dataflow/dataflow_runner.py | 4 ++-- .../apache_beam/runners/pipeline_context.py | 2 +- .../runners/portability/fn_api_runner.py | 2 +- .../apache_beam/runners/worker/data_plane.py | 2 +- .../runners/worker/data_plane_test.py | 2 +- .../apache_beam/runners/worker/log_handler.py | 2 +- .../runners/worker/log_handler_test.py | 2 +- .../apache_beam/runners/worker/sdk_worker.py | 2 +- .../runners/worker/sdk_worker_main.py | 2 +- .../runners/worker/sdk_worker_test.py | 2 +- sdks/python/apache_beam/transforms/core.py | 2 +- .../apache_beam/transforms/ptransform.py | 2 +- sdks/python/apache_beam/transforms/trigger.py | 2 +- sdks/python/apache_beam/transforms/window.py | 4 ++-- sdks/python/apache_beam/utils/urns.py | 2 +- sdks/python/gen_protos.py | 2 +- sdks/python/run_pylint.sh | 2 +- 24 files changed, 60 insertions(+), 24 deletions(-) create mode 100644 sdks/python/apache_beam/portability/__init__.py create mode 100644 sdks/python/apache_beam/portability/runners/__init__.py rename sdks/python/apache_beam/{ => portability}/runners/api/__init__.py (100%) diff --git a/.gitignore b/.gitignore index bd419a78beec5..631d7f32cb965 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ sdks/python/**/*.egg sdks/python/LICENSE sdks/python/NOTICE sdks/python/README.md -sdks/python/apache_beam/runners/api/*pb2*.* +sdks/python/apache_beam/portability/runners/api/*pb2*.* # Ignore IntelliJ files. 
.idea/ diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index f3e0b432e51ca..1be1f3c7a4775 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -25,6 +25,7 @@ import google.protobuf from apache_beam.coders import coder_impl +from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.utils import urns from apache_beam.utils import proto_utils @@ -205,7 +206,6 @@ def to_runner_api(self, context): """For internal use only; no backwards-compatibility guarantees. """ # TODO(BEAM-115): Use specialized URNs and components. - from apache_beam.runners.api import beam_runner_api_pb2 return beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 9093abfccfc3b..cea7215b2b822 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -339,7 +339,7 @@ def visit_value(self, value, _): def to_runner_api(self): """For internal use only; no backwards-compatibility guarantees.""" from apache_beam.runners import pipeline_context - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 context = pipeline_context.PipelineContext() # Mutates context; placing inline would force dependence on # argument evaluation order. @@ -525,7 +525,7 @@ def named_outputs(self): if isinstance(output, pvalue.PCollection)} def to_runner_api(self, context): - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 def transform_to_runner_api(transform, context): if transform is None: diff --git a/sdks/python/apache_beam/portability/__init__.py b/sdks/python/apache_beam/portability/__init__.py new file mode 100644 index 0000000000000..0bce5d68f7243 --- /dev/null +++ b/sdks/python/apache_beam/portability/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""For internal use only; no backwards-compatibility guarantees.""" diff --git a/sdks/python/apache_beam/portability/runners/__init__.py b/sdks/python/apache_beam/portability/runners/__init__.py new file mode 100644 index 0000000000000..0bce5d68f7243 --- /dev/null +++ b/sdks/python/apache_beam/portability/runners/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""For internal use only; no backwards-compatibility guarantees.""" diff --git a/sdks/python/apache_beam/runners/api/__init__.py b/sdks/python/apache_beam/portability/runners/api/__init__.py similarity index 100% rename from sdks/python/apache_beam/runners/api/__init__.py rename to sdks/python/apache_beam/portability/runners/api/__init__.py diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 7385e82c3a5c2..8a774c4c5bf45 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -128,7 +128,7 @@ def __reduce_ex__(self, unused_version): return _InvalidUnpickledPCollection, () def to_runner_api(self, context): - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.internal import pickler return beam_runner_api_pb2.PCollection( unique_name='%d%s.%s' % ( diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index d9aa1bf098d7a..a6cc25d715127 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -732,7 +732,7 @@ def run__NativeWrite(self, transform_node): @classmethod def serialize_windowing_strategy(cls, windowing): from apache_beam.runners import pipeline_context - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 context = pipeline_context.PipelineContext() windowing_proto = windowing.to_runner_api(context) return cls.byte_array_to_json_string( @@ -745,7 +745,7 @@ def deserialize_windowing_strategy(cls, serialized_data): # Imported here to avoid circular dependencies. 
# pylint: disable=wrong-import-order, wrong-import-position from apache_beam.runners import pipeline_context - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.transforms.core import Windowing proto = beam_runner_api_pb2.MessageWithComponents() proto.ParseFromString(cls.json_string_to_byte_array(serialized_data)) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 1c89d0652a26e..1330c3904edfe 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -24,7 +24,7 @@ from apache_beam import pipeline from apache_beam import pvalue from apache_beam import coders -from apache_beam.runners.api import beam_runner_api_pb2 +from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.transforms import core diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index db34ef9671763..a83eae403701e 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -33,7 +33,7 @@ from apache_beam.internal import pickler from apache_beam.io import iobase from apache_beam.transforms.window import GlobalWindows -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.portability import maptask_executor_runner from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import operation_specs diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 7365db69f56a1..734ee9cda36a3 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -28,7 +28,7 @@ import threading from apache_beam.coders import coder_impl -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py index e3e01ac5971f3..a2b31e8eb72ad 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane_test.py +++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py @@ -29,7 +29,7 @@ from concurrent import futures import grpc -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import data_plane diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 59ffbf4f45a2d..dca0e4bd11b4b 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -21,7 +21,7 @@ import Queue as queue import threading -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. 
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py index 8720ca8a3f8b1..6dd018f6ad457 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler_test.py +++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py @@ -22,7 +22,7 @@ from concurrent import futures import grpc -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import log_handler diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 33c50adfa4bb6..33f2b61dd3883 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -38,7 +38,7 @@ from apache_beam.io import iobase from apache_beam.runners.dataflow.native_io import iobase as native_iobase from apache_beam.utils import counters -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import operations diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index b8917791b91eb..9c11068a972e2 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -24,7 +24,7 @@ import grpc from google.protobuf import text_format -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler from apache_beam.runners.worker.sdk_worker import SdkHarness diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index 0d0811b22201b..93f60d3e7424a 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -29,7 +29,7 @@ from apache_beam.io.concat_source_test import RangeSource from apache_beam.io.iobase import SourceBundle -from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.portability.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 0e497f9448d10..d7fa770af3be7 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -27,7 +27,7 @@ from apache_beam import typehints from apache_beam.coders import typecoders from apache_beam.internal import util -from apache_beam.runners.api import beam_runner_api_pb2 +from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.transforms import ptransform from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index bd2a120b8c630..79fe3add26ddf 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -430,7 +430,7 @@ def register_urn(cls, urn, parameter_type, constructor): cls._known_urns[urn] = parameter_type, constructor def to_runner_api(self, 
context): - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.FunctionSpec( urn=urn, diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 42009958552ad..41516070e8e3f 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -33,7 +33,7 @@ from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import WindowedValue from apache_beam.transforms.window import WindowFn -from apache_beam.runners.api import beam_runner_api_pb2 +from apache_beam.portability.runners.api import beam_runner_api_pb2 from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index e87a00763897f..08c7a2d132f7b 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -55,8 +55,8 @@ from google.protobuf import timestamp_pb2 from apache_beam.coders import coders -from apache_beam.runners.api import beam_runner_api_pb2 -from apache_beam.runners.api import standard_window_fns_pb2 +from apache_beam.portability.runners.api import beam_runner_api_pb2 +from apache_beam.portability.runners.api import standard_window_fns_pb2 from apache_beam.transforms import timeutil from apache_beam.utils import proto_utils from apache_beam.utils import urns diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 849b8e37dfbdc..b925bcc9fbcdd 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -102,7 +102,7 @@ def to_runner_api(self, context): Prefer overriding self.to_runner_api_parameter. 
""" - from apache_beam.runners.api import beam_runner_api_pb2 + from apache_beam.portability.runners.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( diff --git a/sdks/python/gen_protos.py b/sdks/python/gen_protos.py index 92b8414583eaf..a33c74b9cd92d 100644 --- a/sdks/python/gen_protos.py +++ b/sdks/python/gen_protos.py @@ -35,7 +35,7 @@ os.path.join('..', 'common', 'fn-api', 'src', 'main', 'proto') ] -PYTHON_OUTPUT_PATH = os.path.join('apache_beam', 'runners', 'api') +PYTHON_OUTPUT_PATH = os.path.join('apache_beam', 'portability', 'runners', 'api') def generate_proto_files(): diff --git a/sdks/python/run_pylint.sh b/sdks/python/run_pylint.sh index 7808136a49252..7434516bfdeb4 100755 --- a/sdks/python/run_pylint.sh +++ b/sdks/python/run_pylint.sh @@ -46,7 +46,7 @@ EXCLUDED_GENERATED_FILES=( "apache_beam/io/gcp/internal/clients/storage/storage_v1_client.py" "apache_beam/io/gcp/internal/clients/storage/storage_v1_messages.py" "apache_beam/coders/proto2_coder_test_messages_pb2.py" -apache_beam/runners/api/*pb2*.py +apache_beam/portability/runners/api/*pb2*.py ) FILES_TO_IGNORE="" From 7caea7a845eff072a647baf69b9b004db4523652 Mon Sep 17 00:00:00 2001 From: Etienne Chauchot Date: Mon, 5 Jun 2017 16:21:58 +0200 Subject: [PATCH 014/200] [BEAM-2410] Remove TransportClient from ElasticSearchIO to decouple IO and ES server versions --- .../sdk/io/common/IOTestPipelineOptions.java | 6 +- .../sdk/io/elasticsearch/ElasticsearchIO.java | 4 +- .../ElasticSearchIOTestUtils.java | 81 ++++++++++--------- .../io/elasticsearch/ElasticsearchIOIT.java | 14 ++-- .../io/elasticsearch/ElasticsearchIOTest.java | 36 +++++---- .../ElasticsearchTestDataSet.java | 37 +++------ 6 files changed, 87 insertions(+), 91 deletions(-) diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/IOTestPipelineOptions.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/IOTestPipelineOptions.java index 387fd226e1aeb..25ab9298ea2fc 100644 --- a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/IOTestPipelineOptions.java +++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/IOTestPipelineOptions.java @@ -71,11 +71,7 @@ public interface IOTestPipelineOptions extends TestPipelineOptions { Integer getElasticsearchHttpPort(); void setElasticsearchHttpPort(Integer value); - @Description("Tcp port for elasticsearch server") - @Default.Integer(9300) - Integer getElasticsearchTcpPort(); - void setElasticsearchTcpPort(Integer value); - + /* Cassandra */ @Description("Host for Cassandra server (host name/ip address)") @Default.String("cassandra-host") String getCassandraHost(); diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index f6ceef2286f5a..e3965dc6a0c01 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -139,7 +139,7 @@ private ElasticsearchIO() {} private static final ObjectMapper mapper = new ObjectMapper(); - private static JsonNode parseResponse(Response response) throws IOException { + static JsonNode parseResponse(Response response) throws IOException { return mapper.readValue(response.getEntity().getContent(), 
JsonNode.class); } @@ -264,7 +264,7 @@ private void populateDisplayData(DisplayData.Builder builder) { builder.addIfNotNull(DisplayData.item("username", getUsername())); } - private RestClient createClient() throws MalformedURLException { + RestClient createClient() throws MalformedURLException { HttpHost[] hosts = new HttpHost[getAddresses().size()]; int i = 0; for (String address : getAddresses()) { diff --git a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticSearchIOTestUtils.java b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticSearchIOTestUtils.java index b0d161fc4ac92..203963d149fcf 100644 --- a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticSearchIOTestUtils.java +++ b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticSearchIOTestUtils.java @@ -17,19 +17,17 @@ */ package org.apache.beam.sdk.io.elasticsearch; +import com.fasterxml.jackson.databind.JsonNode; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; -import org.elasticsearch.action.admin.indices.exists.indices.IndicesExistsRequest; -import org.elasticsearch.action.admin.indices.exists.indices.IndicesExistsResponse; -import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeRequest; -import org.elasticsearch.action.bulk.BulkRequestBuilder; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.client.Client; -import org.elasticsearch.client.IndicesAdminClient; -import org.elasticsearch.client.Requests; -import org.elasticsearch.index.IndexNotFoundException; +import org.apache.http.HttpEntity; +import org.apache.http.entity.ContentType; +import org.apache.http.message.BasicHeader; +import org.apache.http.nio.entity.NStringEntity; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.RestClient; /** Test utilities to use with {@link ElasticsearchIO}. */ class ElasticSearchIOTestUtils { @@ -41,57 +39,68 @@ enum InjectionMode { } /** Deletes the given index synchronously. */ - static void deleteIndex(String index, Client client) throws Exception { - IndicesAdminClient indices = client.admin().indices(); - IndicesExistsResponse indicesExistsResponse = - indices.exists(new IndicesExistsRequest(index)).get(); - if (indicesExistsResponse.isExists()) { - indices.prepareClose(index).get(); - indices.delete(Requests.deleteIndexRequest(index)).get(); + static void deleteIndex(String index, RestClient restClient) throws IOException { + try { + restClient.performRequest("DELETE", String.format("/%s", index), new BasicHeader("", "")); + } catch (IOException e) { + // it is fine to ignore this expression as deleteIndex occurs in @before, + // so when the first tests is run, the index does not exist yet + if (!e.getMessage().contains("index_not_found_exception")){ + throw e; + } } } /** Inserts the given number of test documents into Elasticsearch. 
*/ - static void insertTestDocuments(String index, String type, long numDocs, Client client) - throws Exception { - final BulkRequestBuilder bulkRequestBuilder = client.prepareBulk().setRefresh(true); + static void insertTestDocuments(String index, String type, long numDocs, RestClient restClient) + throws IOException { List data = ElasticSearchIOTestUtils.createDocuments( numDocs, ElasticSearchIOTestUtils.InjectionMode.DO_NOT_INJECT_INVALID_DOCS); + StringBuilder bulkRequest = new StringBuilder(); for (String document : data) { - bulkRequestBuilder.add(client.prepareIndex(index, type, null).setSource(document)); + bulkRequest.append(String.format("{ \"index\" : {} }%n%s%n", document)); } - final BulkResponse bulkResponse = bulkRequestBuilder.execute().actionGet(); - if (bulkResponse.hasFailures()) { + String endPoint = String.format("/%s/%s/_bulk", index, type); + HttpEntity requestBody = + new NStringEntity(bulkRequest.toString(), ContentType.APPLICATION_JSON); + Response response = restClient.performRequest("POST", endPoint, + Collections.singletonMap("refresh", "true"), requestBody, + new BasicHeader("", "")); + JsonNode searchResult = ElasticsearchIO.parseResponse(response); + boolean errors = searchResult.path("errors").asBoolean(); + if (errors){ throw new IOException( - String.format( - "Cannot insert test documents in index %s : %s", - index, bulkResponse.buildFailureMessage())); + String.format("Failed to insert test documents in index %s", index)); } } /** - * Forces an upgrade of the given index to make recently inserted documents available for search. + * Forces a refresh of the given index to make recently inserted documents available for search. * * @return The number of docs in the index */ - static long upgradeIndexAndGetCurrentNumDocs(String index, String type, Client client) { + static long refreshIndexAndGetCurrentNumDocs(String index, String type, RestClient restClient) + throws IOException { + long result = 0; try { - client.admin().indices().upgrade(new UpgradeRequest(index)).actionGet(); - SearchResponse response = - client.prepareSearch(index).setTypes(type).execute().actionGet(5000); - return response.getHits().getTotalHits(); + String endPoint = String.format("/%s/_refresh", index); + restClient.performRequest("POST", endPoint, new BasicHeader("", "")); + + endPoint = String.format("/%s/%s/_search", index, type); + Response response = restClient.performRequest("GET", endPoint, new BasicHeader("", "")); + JsonNode searchResult = ElasticsearchIO.parseResponse(response); + result = searchResult.path("hits").path("total").asLong(); + } catch (IOException e) { // it is fine to ignore bellow exceptions because in testWriteWithBatchSize* sometimes, // we call upgrade before any doc have been written // (when there are fewer docs processed than batchSize). 
// In that cases index/type has not been created (created upon first doc insertion) - } catch (IndexNotFoundException e) { - } catch (java.lang.IllegalArgumentException e) { - if (!e.getMessage().contains("No search type")) { + if (!e.getMessage().contains("index_not_found_exception")){ throw e; } } - return 0; + return result; } /** diff --git a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOIT.java b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOIT.java index 2d6393adc56fc..7c37e8745347c 100644 --- a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOIT.java +++ b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOIT.java @@ -32,7 +32,7 @@ import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; -import org.elasticsearch.client.transport.TransportClient; +import org.elasticsearch.client.RestClient; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; @@ -57,7 +57,7 @@ */ public class ElasticsearchIOIT { private static final Logger LOG = LoggerFactory.getLogger(ElasticsearchIOIT.class); - private static TransportClient client; + private static RestClient restClient; private static IOTestPipelineOptions options; private static ElasticsearchIO.ConnectionConfiguration readConnectionConfiguration; @Rule public TestPipeline pipeline = TestPipeline.create(); @@ -66,16 +66,16 @@ public class ElasticsearchIOIT { public static void beforeClass() throws Exception { PipelineOptionsFactory.register(IOTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(IOTestPipelineOptions.class); - client = ElasticsearchTestDataSet.getClient(options); readConnectionConfiguration = ElasticsearchTestDataSet.getConnectionConfiguration( options, ElasticsearchTestDataSet.ReadOrWrite.READ); + restClient = readConnectionConfiguration.createClient(); } @AfterClass public static void afterClass() throws Exception { - ElasticsearchTestDataSet.deleteIndex(client, ElasticsearchTestDataSet.ReadOrWrite.WRITE); - client.close(); + ElasticsearchTestDataSet.deleteIndex(restClient, ElasticsearchTestDataSet.ReadOrWrite.WRITE); + restClient.close(); } @Test @@ -128,8 +128,8 @@ public void testWriteVolume() throws Exception { pipeline.run(); long currentNumDocs = - ElasticSearchIOTestUtils.upgradeIndexAndGetCurrentNumDocs( - ElasticsearchTestDataSet.ES_INDEX, ElasticsearchTestDataSet.ES_TYPE, client); + ElasticSearchIOTestUtils.refreshIndexAndGetCurrentNumDocs( + ElasticsearchTestDataSet.ES_INDEX, ElasticsearchTestDataSet.ES_TYPE, restClient); assertEquals(ElasticsearchTestDataSet.NUM_DOCS, currentNumDocs); } diff --git a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 260af79bb0a4e..b349a29a4fcfd 100644 --- a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -39,11 +39,11 @@ import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.values.PCollection; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.RestClient; import 
org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.node.Node; -import org.elasticsearch.node.NodeBuilder; import org.hamcrest.CustomMatcher; import org.junit.AfterClass; import org.junit.Before; @@ -74,9 +74,10 @@ public class ElasticsearchIOTest implements Serializable { private static final long BATCH_SIZE_BYTES = 2048L; private static Node node; + private static RestClient restClient; private static ElasticsearchIO.ConnectionConfiguration connectionConfiguration; - @ClassRule public static TemporaryFolder folder = new TemporaryFolder(); + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); @Rule public TestPipeline pipeline = TestPipeline.create(); @@ -91,8 +92,8 @@ public static void beforeClass() throws IOException { .put("cluster.name", "beam") .put("http.enabled", "true") .put("node.data", "true") - .put("path.data", folder.getRoot().getPath()) - .put("path.home", folder.getRoot().getPath()) + .put("path.data", TEMPORARY_FOLDER.getRoot().getPath()) + .put("path.home", TEMPORARY_FOLDER.getRoot().getPath()) .put("node.name", "beam") .put("network.host", ES_IP) .put("http.port", esHttpPort) @@ -100,27 +101,29 @@ public static void beforeClass() throws IOException { // had problems with some jdk, embedded ES was too slow for bulk insertion, // and queue of 50 was full. No pb with real ES instance (cf testWrite integration test) .put("threadpool.bulk.queue_size", 100); - node = NodeBuilder.nodeBuilder().settings(settingsBuilder).build(); + node = new Node(settingsBuilder.build()); LOG.info("Elasticsearch node created"); node.start(); connectionConfiguration = ElasticsearchIO.ConnectionConfiguration.create( new String[] {"http://" + ES_IP + ":" + esHttpPort}, ES_INDEX, ES_TYPE); + restClient = connectionConfiguration.createClient(); } @AfterClass - public static void afterClass() { + public static void afterClass() throws IOException{ + restClient.close(); node.close(); } @Before public void before() throws Exception { - ElasticSearchIOTestUtils.deleteIndex(ES_INDEX, node.client()); + ElasticSearchIOTestUtils.deleteIndex(ES_INDEX, restClient); } @Test public void testSizes() throws Exception { - ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, node.client()); + ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient); PipelineOptions options = PipelineOptionsFactory.create(); ElasticsearchIO.Read read = ElasticsearchIO.read().withConnectionConfiguration(connectionConfiguration); @@ -134,7 +137,7 @@ public void testSizes() throws Exception { @Test public void testRead() throws Exception { - ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, node.client()); + ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient); PCollection output = pipeline.apply( @@ -150,7 +153,7 @@ public void testRead() throws Exception { @Test public void testReadWithQuery() throws Exception { - ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, node.client()); + ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient); String query = "{\n" @@ -185,7 +188,7 @@ public void testWrite() throws Exception { pipeline.run(); long currentNumDocs = - ElasticSearchIOTestUtils.upgradeIndexAndGetCurrentNumDocs(ES_INDEX, ES_TYPE, node.client()); + ElasticSearchIOTestUtils.refreshIndexAndGetCurrentNumDocs(ES_INDEX, 
ES_TYPE, restClient); assertEquals(NUM_DOCS, currentNumDocs); QueryBuilder queryBuilder = QueryBuilders.queryStringQuery("Einstein").field("scientist"); @@ -258,9 +261,8 @@ public void testWriteWithMaxBatchSize() throws Exception { if ((numDocsProcessed % 100) == 0) { // force the index to upgrade after inserting for the inserted docs // to be searchable immediately - long currentNumDocs = - ElasticSearchIOTestUtils.upgradeIndexAndGetCurrentNumDocs( - ES_INDEX, ES_TYPE, node.client()); + long currentNumDocs = ElasticSearchIOTestUtils + .refreshIndexAndGetCurrentNumDocs(ES_INDEX, ES_TYPE, restClient); if ((numDocsProcessed % BATCH_SIZE) == 0) { /* bundle end */ assertEquals( @@ -304,8 +306,8 @@ public void testWriteWithMaxBatchSizeBytes() throws Exception { // force the index to upgrade after inserting for the inserted docs // to be searchable immediately long currentNumDocs = - ElasticSearchIOTestUtils.upgradeIndexAndGetCurrentNumDocs( - ES_INDEX, ES_TYPE, node.client()); + ElasticSearchIOTestUtils.refreshIndexAndGetCurrentNumDocs( + ES_INDEX, ES_TYPE, restClient); if (sizeProcessed / BATCH_SIZE_BYTES > batchInserted) { /* bundle end */ assertThat( @@ -327,7 +329,7 @@ public void testWriteWithMaxBatchSizeBytes() throws Exception { @Test public void testSplit() throws Exception { - ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, node.client()); + ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient); PipelineOptions options = PipelineOptionsFactory.create(); ElasticsearchIO.Read read = ElasticsearchIO.read().withConnectionConfiguration(connectionConfiguration); diff --git a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java index 3a9aae6098f10..2a2dbe902defe 100644 --- a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java +++ b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java @@ -17,13 +17,11 @@ */ package org.apache.beam.sdk.io.elasticsearch; -import static java.net.InetAddress.getByName; import java.io.IOException; import org.apache.beam.sdk.io.common.IOTestPipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.elasticsearch.client.transport.TransportClient; -import org.elasticsearch.common.transport.InetSocketTransportAddress; +import org.elasticsearch.client.RestClient; /** * Manipulates test data used by the {@link ElasticsearchIO} @@ -51,7 +49,6 @@ public class ElasticsearchTestDataSet { * -Dexec.mainClass=org.apache.beam.sdk.io.elasticsearch.ElasticsearchTestDataSet \ * -Dexec.args="--elasticsearchServer=1.2.3.4 \ * --elasticsearchHttpPort=9200 \ - * --elasticsearchTcpPort=9300" \ * -Dexec.classpathScope=test * * @@ -62,29 +59,20 @@ public static void main(String[] args) throws Exception { PipelineOptionsFactory.register(IOTestPipelineOptions.class); IOTestPipelineOptions options = PipelineOptionsFactory.fromArgs(args).as(IOTestPipelineOptions.class); - - createAndPopulateIndex(getClient(options), ReadOrWrite.READ); + createAndPopulateReadIndex(options); } - private static void createAndPopulateIndex(TransportClient client, ReadOrWrite rOw) - throws Exception { + private static void createAndPopulateReadIndex(IOTestPipelineOptions options) throws Exception { + RestClient restClient = 
getConnectionConfiguration(options, ReadOrWrite.READ).createClient(); // automatically creates the index and insert docs - ElasticSearchIOTestUtils.insertTestDocuments( - (rOw == ReadOrWrite.READ) ? ES_INDEX : writeIndex, ES_TYPE, NUM_DOCS, client); - } - - public static TransportClient getClient(IOTestPipelineOptions options) throws Exception { - TransportClient client = - TransportClient.builder() - .build() - .addTransportAddress( - new InetSocketTransportAddress( - getByName(options.getElasticsearchServer()), - options.getElasticsearchTcpPort())); - return client; + try { + ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient); + } finally { + restClient.close(); + } } - public static ElasticsearchIO.ConnectionConfiguration getConnectionConfiguration( + static ElasticsearchIO.ConnectionConfiguration getConnectionConfiguration( IOTestPipelineOptions options, ReadOrWrite rOw) throws IOException { ElasticsearchIO.ConnectionConfiguration connectionConfiguration = ElasticsearchIO.ConnectionConfiguration.create( @@ -99,8 +87,9 @@ public static ElasticsearchIO.ConnectionConfiguration getConnectionConfiguration return connectionConfiguration; } - public static void deleteIndex(TransportClient client, ReadOrWrite rOw) throws Exception { - ElasticSearchIOTestUtils.deleteIndex((rOw == ReadOrWrite.READ) ? ES_INDEX : writeIndex, client); + static void deleteIndex(RestClient restClient, ReadOrWrite rOw) throws Exception { + ElasticSearchIOTestUtils + .deleteIndex((rOw == ReadOrWrite.READ) ? ES_INDEX : writeIndex, restClient); } /** Enum that tells whether we use the index for reading or for writing. */ From b7ae7ecffcd08b6a0ccc8296210d36b90306c171 Mon Sep 17 00:00:00 2001 From: Mark Liu Date: Wed, 7 Jun 2017 16:27:34 -0700 Subject: [PATCH 015/200] Fix compile error occurs in some JDKs --- .../flink/FlinkStreamingTransformTranslators.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index ef46b63ae5619..fef32de77edfa 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -363,8 +363,13 @@ static void translateParDo( Map, OutputTag>> tagsToOutputTags = Maps.newHashMap(); for (Map.Entry, PValue> entry : outputs.entrySet()) { if (!tagsToOutputTags.containsKey(entry.getKey())) { - tagsToOutputTags.put(entry.getKey(), new OutputTag<>(entry.getKey().getId(), - (TypeInformation) context.getTypeInfo((PCollection) entry.getValue()))); + tagsToOutputTags.put( + entry.getKey(), + new OutputTag>( + entry.getKey().getId(), + (TypeInformation) context.getTypeInfo((PCollection) entry.getValue()) + ) + ); } } From fb61c540bc15bafb959d7accb7c08f6a681f62ef Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 8 Jun 2017 15:01:53 -0700 Subject: [PATCH 016/200] Use beam.Map rather than beam.ParDo for PubSub encoding. 
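
A minimal runnable sketch of the substitution this patch makes (illustrative only, not part of the patch): the module-level encode/decode DoFns are replaced by Map with inline lambdas, which suffices because each step is a one-to-one element mapping. The pipeline contents below are made up for the example; only the 'EncodeString'/'DecodeString' labels and the lambdas mirror the change.

    # Sketch with assumed example data; runnable on the DirectRunner.
    import apache_beam as beam

    with beam.Pipeline() as p:
        _ = (
            p
            | 'CreateStrings' >> beam.Create([u'hello', u'world'])
            # Equivalent to the removed ParDo(_encodeUtf8String) step.
            | 'EncodeString' >> beam.Map(lambda s: s.encode('utf-8'))
            # Equivalent to the removed ParDo(_decodeUtf8String) step.
            | 'DecodeString' >> beam.Map(lambda b: b.decode('utf-8')))
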
--- sdks/python/apache_beam/io/gcp/pubsub.py | 16 +++------------- sdks/python/apache_beam/io/gcp/pubsub_test.py | 10 ---------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 40326e10295aa..6dc15288276da 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -29,7 +29,7 @@ from apache_beam.io.iobase import Write from apache_beam.runners.dataflow.native_io import iobase as dataflow_io from apache_beam.transforms import PTransform -from apache_beam.transforms import ParDo +from apache_beam.transforms import Map from apache_beam.transforms.display import DisplayDataItem @@ -71,7 +71,7 @@ def __init__(self, topic=None, subscription=None, id_label=None): def expand(self, pvalue): pcoll = pvalue.pipeline | Read(self._source) pcoll.element_type = bytes - pcoll = pcoll | 'decode string' >> ParDo(_decodeUtf8String) + pcoll = pcoll | 'DecodeString' >> Map(lambda b: b.decode('utf-8')) pcoll.element_type = unicode return pcoll @@ -89,7 +89,7 @@ def __init__(self, topic): self._sink = _PubSubPayloadSink(topic) def expand(self, pcoll): - pcoll = pcoll | 'encode string' >> ParDo(_encodeUtf8String) + pcoll = pcoll | 'EncodeString' >> Map(lambda s: s.encode('utf-8')) pcoll.element_type = bytes return pcoll | Write(self._sink) @@ -162,16 +162,6 @@ def writer(self): 'PubSubPayloadSink is not supported in local execution.') -def _decodeUtf8String(encoded_value): - """Decodes a string in utf-8 format from bytes""" - return encoded_value.decode('utf-8') - - -def _encodeUtf8String(value): - """Encodes a string in utf-8 format to bytes""" - return value.encode('utf-8') - - class PubSubSource(dataflow_io.NativeSource): """Deprecated: do not use. diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index cf14e8c1d9217..5d3e985597c0b 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -22,8 +22,6 @@ import hamcrest as hc -from apache_beam.io.gcp.pubsub import _decodeUtf8String -from apache_beam.io.gcp.pubsub import _encodeUtf8String from apache_beam.io.gcp.pubsub import _PubSubPayloadSink from apache_beam.io.gcp.pubsub import _PubSubPayloadSource from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub @@ -120,14 +118,6 @@ def test_display_data(self): hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) -class TestEncodeDecodeUtf8String(unittest.TestCase): - def test_encode(self): - self.assertEqual(b'test_data', _encodeUtf8String('test_data')) - - def test_decode(self): - self.assertEqual('test_data', _decodeUtf8String(b'test_data')) - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() From 261e7df2b860fe82d9f401e2621b020fe2020fea Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 6 Jun 2017 16:15:19 -0700 Subject: [PATCH 017/200] Visit a Transform Hierarchy in Topological Order This reverts commit 6ad6433ec0c02aec8656e9e3b27f6e0f974f8ece. 
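
A minimal sketch of a visitor that depends on the restored topological guarantee (illustrative only, not part of this patch; the ProducerMapVisitor name is an assumption for the example): it records the producer of each PValue in visitValue and relies on that record already existing for every input when the consuming primitive node is visited.

    // Sketch of a PipelineVisitor that assumes producers are visited before consumers.
    import java.util.HashMap;
    import java.util.Map;
    import org.apache.beam.sdk.Pipeline.PipelineVisitor;
    import org.apache.beam.sdk.runners.TransformHierarchy.Node;
    import org.apache.beam.sdk.values.PValue;

    class ProducerMapVisitor extends PipelineVisitor.Defaults {
      private final Map<PValue, Node> producers = new HashMap<>();

      @Override
      public void visitValue(PValue value, Node producer) {
        // With topological ordering this runs before any consumer of value is visited.
        producers.put(value, producer);
      }

      @Override
      public void visitPrimitiveTransform(Node node) {
        for (PValue input : node.getInputs().values()) {
          // Safe lookup: the producer of every input has already been visited.
          Node producer = producers.get(input);
          if (producer == null) {
            throw new IllegalStateException("Input visited before its producer: " + input);
          }
        }
      }
    }
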
--- .../spark/translation/StorageLevelTest.java | 4 +- .../beam/sdk/runners/TransformHierarchy.java | 79 ++++++- .../sdk/runners/TransformHierarchyTest.java | 197 ++++++++++++++++++ 3 files changed, 274 insertions(+), 6 deletions(-) diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java index 8f2e681c9e91f..8bd6dae98a74f 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java @@ -58,12 +58,12 @@ public static void teardown() { @Test public void test() throws Exception { - PCollection pCollection = pipeline.apply(Create.of("foo")); + PCollection pCollection = pipeline.apply("CreateFoo", Create.of("foo")); // by default, the Spark runner doesn't cache the RDD if it accessed only one time. // So, to "force" the caching of the RDD, we have to call the RDD at least two time. // That's why we are using Count fn on the PCollection. - pCollection.apply(Count.globally()); + pCollection.apply("CountAll", Count.globally()); PCollection output = pCollection.apply(new StorageLevelPTransform()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index ee1ce7b2b6683..5e048ebb08197 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -208,7 +208,7 @@ Node getProducer(PValue produced) { public Set visit(PipelineVisitor visitor) { finishSpecifying(); Set visitedValues = new HashSet<>(); - root.visit(visitor, visitedValues); + root.visit(visitor, visitedValues, new HashSet(), new HashSet()); return visitedValues; } @@ -503,10 +503,60 @@ public Map, PValue> getOutputs() { /** * Visit the transform node. * + *

+   * <p>The visit proceeds in the following order:
+   *
+   * <ul>
+   *   <li>Visit all input {@link PValue PValues} returned by the flattened expansion of {@link
+   *       Node#getInputs()}.
+   *   <li>If the node is a composite:
+   *       <ul>
+   *         <li>Enter the node via {@link PipelineVisitor#enterCompositeTransform(Node)}.
+   *         <li>If the result of {@link PipelineVisitor#enterCompositeTransform(Node)} was {@link
+   *             CompositeBehavior#ENTER_TRANSFORM}, visit each child node of this {@link Node}.
+   *         <li>Leave the node via {@link PipelineVisitor#leaveCompositeTransform(Node)}.
+   *       </ul>
+   *   <li>If the node is a primitive, visit it via {@link
+   *       PipelineVisitor#visitPrimitiveTransform(Node)}.
+   *   <li>Visit each {@link PValue} that was output by this node.
+   * </ul>
+   *
+   * <p>Additionally, the following ordering restrictions are observed:
+   *
+   * <ul>
+   *   <li>A {@link Node} will be visited after its enclosing node has been entered and before its
+   *       enclosing node has been left
+   *   <li>A {@link Node} will not be visited if any enclosing {@link Node} has returned {@link
+   *       CompositeBehavior#DO_NOT_ENTER_TRANSFORM} from the call to {@link
+   *       PipelineVisitor#enterCompositeTransform(Node)}.
+   *   <li>A {@link PValue} will only be visited after the {@link Node} that originally produced
+   *       it has been visited.
+   * </ul>
+   *
Provides an ordered visit of the input values, the primitive transform (or child nodes for * composite transforms), then the output values. */ - private void visit(PipelineVisitor visitor, Set visitedValues) { + private void visit( + PipelineVisitor visitor, + Set visitedValues, + Set visitedNodes, + Set skippedComposites) { + if (getEnclosingNode() != null && !visitedNodes.contains(getEnclosingNode())) { + // Recursively enter all enclosing nodes, as appropriate. + getEnclosingNode().visit(visitor, visitedValues, visitedNodes, skippedComposites); + } + // These checks occur after visiting the enclosing node to ensure that if this node has been + // visited while visiting the enclosing node the node is not revisited, or, if an enclosing + // Node is skipped, this node is also skipped. + if (!visitedNodes.add(this)) { + LOG.debug("Not revisiting previously visited node {}", this); + return; + } else if (childNodeOf(skippedComposites)) { + // This node is a child of a node that has been passed over via CompositeBehavior, and + // should also be skipped. All child nodes of a skipped composite should always be skipped. + LOG.debug("Not revisiting Node {} which is a child of a previously passed composite", this); + return; + } + if (!finishedSpecifying) { finishSpecifying(); } @@ -514,22 +564,31 @@ private void visit(PipelineVisitor visitor, Set visitedValues) { if (!isRootNode()) { // Visit inputs. for (PValue inputValue : inputs.values()) { + Node valueProducer = getProducer(inputValue); + if (!visitedNodes.contains(valueProducer)) { + valueProducer.visit(visitor, visitedValues, visitedNodes, skippedComposites); + } if (visitedValues.add(inputValue)) { - visitor.visitValue(inputValue, getProducer(inputValue)); + LOG.debug("Visiting input value {}", inputValue); + visitor.visitValue(inputValue, valueProducer); } } } if (isCompositeNode()) { + LOG.debug("Visiting composite node {}", this); PipelineVisitor.CompositeBehavior recurse = visitor.enterCompositeTransform(this); if (recurse.equals(CompositeBehavior.ENTER_TRANSFORM)) { for (Node child : parts) { - child.visit(visitor, visitedValues); + child.visit(visitor, visitedValues, visitedNodes, skippedComposites); } + } else { + skippedComposites.add(this); } visitor.leaveCompositeTransform(this); } else { + LOG.debug("Visiting primitive node {}", this); visitor.visitPrimitiveTransform(this); } @@ -538,12 +597,24 @@ private void visit(PipelineVisitor visitor, Set visitedValues) { // Visit outputs. for (PValue pValue : outputs.values()) { if (visitedValues.add(pValue)) { + LOG.debug("Visiting output value {}", pValue); visitor.visitValue(pValue, this); } } } } + private boolean childNodeOf(Set nodes) { + if (isRootNode()) { + return false; + } + Node parent = this.getEnclosingNode(); + while (!parent.isRootNode() && !nodes.contains(parent)) { + parent = parent.getEnclosingNode(); + } + return nodes.contains(parent); + } + /** * Finish specifying a transform. 
* diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java index 1197d1b04eb67..93650dd8c9836 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; @@ -32,6 +33,8 @@ import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.Read; @@ -492,4 +495,198 @@ public void visitPrimitiveTransform(Node node) { assertThat(visitedPrimitiveNodes, containsInAnyOrder(upstreamNode, replacementParNode)); assertThat(visitedValues, Matchers.containsInAnyOrder(upstream, output)); } + + @Test + public void visitIsTopologicallyOrdered() { + PCollection one = + PCollection.createPrimitiveOutputInternal( + pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED) + .setCoder(StringUtf8Coder.of()); + final PCollection two = + PCollection.createPrimitiveOutputInternal( + pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) + .setCoder(VarIntCoder.of()); + final PDone done = PDone.in(pipeline); + final TupleTag oneTag = new TupleTag() {}; + final TupleTag twoTag = new TupleTag() {}; + final PCollectionTuple oneAndTwo = PCollectionTuple.of(oneTag, one).and(twoTag, two); + + PTransform, PDone> multiConsumer = + new PTransform, PDone>() { + @Override + public PDone expand(PCollection input) { + return done; + } + + @Override + public Map, PValue> getAdditionalInputs() { + return Collections., PValue>singletonMap(twoTag, two); + } + }; + hierarchy.pushNode("consumes_both", one, multiConsumer); + hierarchy.setOutput(done); + hierarchy.popNode(); + + final PTransform producer = + new PTransform() { + @Override + public PCollectionTuple expand(PBegin input) { + return oneAndTwo; + } + }; + hierarchy.pushNode( + "encloses_producer", + PBegin.in(pipeline), + new PTransform() { + @Override + public PCollectionTuple expand(PBegin input) { + return input.apply(producer); + } + }); + hierarchy.pushNode( + "creates_one_and_two", + PBegin.in(pipeline), producer); + hierarchy.setOutput(oneAndTwo); + hierarchy.popNode(); + hierarchy.setOutput(oneAndTwo); + hierarchy.popNode(); + + hierarchy.pushNode("second_copy_of_consumes_both", one, multiConsumer); + hierarchy.setOutput(done); + hierarchy.popNode(); + + final Set visitedNodes = new HashSet<>(); + final Set exitedNodes = new HashSet<>(); + final Set visitedValues = new HashSet<>(); + hierarchy.visit( + new PipelineVisitor.Defaults() { + + @Override + public CompositeBehavior enterCompositeTransform(Node node) { + for (PValue input : node.getInputs().values()) { + assertThat(visitedValues, hasItem(input)); + } + assertThat( + "Nodes should not be visited more than once", visitedNodes, not(hasItem(node))); + if (!node.isRootNode()) { + assertThat( + "Nodes should always be visited after their enclosing nodes", + visitedNodes, 
+ hasItem(node.getEnclosingNode())); + } + visitedNodes.add(node); + return CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void leaveCompositeTransform(Node node) { + assertThat(visitedNodes, hasItem(node)); + if (!node.isRootNode()) { + assertThat( + "Nodes should always be left before their enclosing nodes are left", + exitedNodes, + not(hasItem(node.getEnclosingNode()))); + } + assertThat(exitedNodes, not(hasItem(node))); + exitedNodes.add(node); + } + + @Override + public void visitPrimitiveTransform(Node node) { + assertThat(visitedNodes, hasItem(node.getEnclosingNode())); + assertThat(exitedNodes, not(hasItem(node.getEnclosingNode()))); + assertThat( + "Nodes should not be visited more than once", visitedNodes, not(hasItem(node))); + for (PValue input : node.getInputs().values()) { + assertThat(visitedValues, hasItem(input)); + } + visitedNodes.add(node); + } + + @Override + public void visitValue(PValue value, Node producer) { + assertThat(visitedNodes, hasItem(producer)); + assertThat(visitedValues, not(hasItem(value))); + visitedValues.add(value); + } + }); + assertThat("Should have visited all the nodes", visitedNodes.size(), equalTo(5)); + assertThat("Should have left all of the visited composites", exitedNodes.size(), equalTo(2)); + } + + @Test + public void visitDoesNotVisitSkippedNodes() { + PCollection one = + PCollection.createPrimitiveOutputInternal( + pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED) + .setCoder(StringUtf8Coder.of()); + final PCollection two = + PCollection.createPrimitiveOutputInternal( + pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) + .setCoder(VarIntCoder.of()); + final PDone done = PDone.in(pipeline); + final TupleTag oneTag = new TupleTag() {}; + final TupleTag twoTag = new TupleTag() {}; + final PCollectionTuple oneAndTwo = PCollectionTuple.of(oneTag, one).and(twoTag, two); + + hierarchy.pushNode( + "consumes_both", + one, + new PTransform, PDone>() { + @Override + public PDone expand(PCollection input) { + return done; + } + + @Override + public Map, PValue> getAdditionalInputs() { + return Collections., PValue>singletonMap(twoTag, two); + } + }); + hierarchy.setOutput(done); + hierarchy.popNode(); + + final PTransform producer = + new PTransform() { + @Override + public PCollectionTuple expand(PBegin input) { + return oneAndTwo; + } + }; + final Node enclosing = + hierarchy.pushNode( + "encloses_producer", + PBegin.in(pipeline), + new PTransform() { + @Override + public PCollectionTuple expand(PBegin input) { + return input.apply(producer); + } + }); + Node enclosed = hierarchy.pushNode("creates_one_and_two", PBegin.in(pipeline), producer); + hierarchy.setOutput(oneAndTwo); + hierarchy.popNode(); + hierarchy.setOutput(oneAndTwo); + hierarchy.popNode(); + + final Set visitedNodes = new HashSet<>(); + hierarchy.visit( + new PipelineVisitor.Defaults() { + @Override + public CompositeBehavior enterCompositeTransform(Node node) { + visitedNodes.add(node); + return node.equals(enclosing) + ? 
CompositeBehavior.DO_NOT_ENTER_TRANSFORM + : CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void visitPrimitiveTransform(Node node) { + visitedNodes.add(node); + } + }); + + assertThat(visitedNodes, hasItem(enclosing)); + assertThat(visitedNodes, not(hasItem(enclosed))); + } } From 696f8b28a3a17e7de81e2d46bb9774d57d6e265e Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 6 Jun 2017 17:00:09 -0700 Subject: [PATCH 018/200] Roll-forward Include Additional PTransform inputs in Transform Nodes Update DirectGraph to have All and Non-Additional Inputs This reverts commit 247f9bc1581984d026764b3d433cb594e700bc21 --- .../apex/translation/TranslationContext.java | 4 +- .../core/construction/TransformInputs.java | 50 ++++++ .../construction/TransformInputsTest.java | 166 ++++++++++++++++++ .../beam/runners/direct/DirectGraph.java | 34 +++- .../runners/direct/DirectGraphVisitor.java | 26 ++- .../ExecutorServiceParallelExecutor.java | 2 +- .../runners/direct/ParDoEvaluatorFactory.java | 9 +- ...ttableProcessElementsEvaluatorFactory.java | 2 + .../direct/StatefulParDoEvaluatorFactory.java | 1 + .../beam/runners/direct/WatermarkManager.java | 14 +- .../direct/DirectGraphVisitorTest.java | 10 +- .../runners/direct/EvaluationContextTest.java | 2 +- .../runners/direct/ParDoEvaluatorTest.java | 6 +- .../flink/FlinkBatchTranslationContext.java | 3 +- .../FlinkStreamingTranslationContext.java | 3 +- .../dataflow/DataflowPipelineTranslator.java | 5 +- .../spark/translation/EvaluationContext.java | 4 +- .../beam/sdk/runners/TransformHierarchy.java | 28 ++- 18 files changed, 323 insertions(+), 46 deletions(-) create mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java create mode 100644 runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java index aff3863624c42..94d13e177decb 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -93,7 +94,8 @@ public Map, PValue> getInputs() { } public InputT getInput() { - return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs().values()); + return (InputT) + Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); } public Map, PValue> getOutputs() { diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java new file mode 100644 index 0000000000000..2baf93a3c1282 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java @@ -0,0 +1,50 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.Map; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */ +public class TransformInputs { + /** + * Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link + * PTransform#getAdditionalInputs()}. + */ + public static Collection nonAdditionalInputs(AppliedPTransform application) { + ImmutableList.Builder mainInputs = ImmutableList.builder(); + PTransform transform = application.getTransform(); + for (Map.Entry, PValue> input : application.getInputs().entrySet()) { + if (!transform.getAdditionalInputs().containsKey(input.getKey())) { + mainInputs.add(input.getValue()); + } + } + checkArgument( + !mainInputs.build().isEmpty() || application.getInputs().isEmpty(), + "Expected at least one main input if any inputs exist"); + return mainInputs.build(); + } +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java new file mode 100644 index 0000000000000..f5b2c11e7923d --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.core.construction; + +import static org.junit.Assert.assertThat; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TransformInputs}. */ +@RunWith(JUnit4.class) +public class TransformInputsTest { + @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void nonAdditionalInputsWithNoInputSucceeds() { + AppliedPTransform transform = + AppliedPTransform.of( + "input-free", + Collections., PValue>emptyMap(), + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat(TransformInputs.nonAdditionalInputs(transform), Matchers.empty()); + } + + @Test + public void nonAdditionalInputsWithOneMainInputSucceeds() { + PCollection input = pipeline.apply(GenerateSequence.from(1L)); + AppliedPTransform transform = + AppliedPTransform.of( + "input-single", + Collections., PValue>singletonMap(new TupleTag() {}, input), + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), Matchers.containsInAnyOrder(input)); + } + + @Test + public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { + Map, PValue> allInputs = new HashMap<>(); + PCollection mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag() {}, mainInts); + PCollection voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put(new TupleTag() {}, voids); + AppliedPTransform transform = + AppliedPTransform.of( + "additional-free", + allInputs, + Collections., PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.containsInAnyOrder(voids, mainInts)); + } + + @Test + public void nonAdditionalInputsWithAdditionalInputsSucceeds() { + Map, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag() {}, pipeline.apply(GenerateSequence.from(3L))); + + Map, PValue> allInputs = new HashMap<>(); + PCollection mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag() {}, mainInts); + PCollection voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put( + new TupleTag() {}, voids); + allInputs.putAll(additionalInputs); + + AppliedPTransform transform = + AppliedPTransform.of( + "additional", + allInputs, + Collections., PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.containsInAnyOrder(mainInts, 
voids)); + } + + @Test + public void nonAdditionalInputsWithOnlyAdditionalInputsThrows() { + Map, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag() {}, pipeline.apply(GenerateSequence.from(3L))); + + AppliedPTransform transform = + AppliedPTransform.of( + "additional-only", + additionalInputs, + Collections., PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("at least one"); + TransformInputs.nonAdditionalInputs(transform); + } + + private static class TestTransform extends PTransform { + private final Map, PValue> additionalInputs; + + private TestTransform() { + this(Collections., PValue>emptyMap()); + } + + private TestTransform(Map, PValue> additionalInputs) { + this.additionalInputs = additionalInputs; + } + + @Override + public POutput expand(PInput input) { + return PDone.in(input.getPipeline()); + } + + @Override + public Map, PValue> getAdditionalInputs() { + return additionalInputs; + } + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java index 9ca745d4670d9..ad17b2b5c3e25 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.direct; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.common.collect.ListMultimap; import java.util.Collection; import java.util.List; @@ -36,7 +38,8 @@ class DirectGraph { private final Map, AppliedPTransform> producers; private final Map, AppliedPTransform> viewWriters; - private final ListMultimap> primitiveConsumers; + private final ListMultimap> perElementConsumers; + private final ListMultimap> allConsumers; private final Set> rootTransforms; private final Map, String> stepNames; @@ -44,23 +47,36 @@ class DirectGraph { public static DirectGraph create( Map, AppliedPTransform> producers, Map, AppliedPTransform> viewWriters, - ListMultimap> primitiveConsumers, + ListMultimap> perElementConsumers, + ListMultimap> allConsumers, Set> rootTransforms, Map, String> stepNames) { - return new DirectGraph(producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); + return new DirectGraph( + producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); } private DirectGraph( Map, AppliedPTransform> producers, Map, AppliedPTransform> viewWriters, - ListMultimap> primitiveConsumers, + ListMultimap> perElementConsumers, + ListMultimap> allConsumers, Set> rootTransforms, Map, String> stepNames) { this.producers = producers; this.viewWriters = viewWriters; - this.primitiveConsumers = primitiveConsumers; + this.perElementConsumers = perElementConsumers; + this.allConsumers = allConsumers; this.rootTransforms = rootTransforms; this.stepNames = stepNames; + for (AppliedPTransform step : stepNames.keySet()) { + for (PValue input : step.getInputs().values()) { + checkArgument( + allConsumers.get(input).contains(step), + "Step %s lists value %s as input, but it is not in the graph of consumers", + step.getFullName(), + input); + } + } } AppliedPTransform getProducer(PCollection produced) { @@ -71,8 +87,12 @@ private DirectGraph( return viewWriters.get(view); } - List> 
getPrimitiveConsumers(PValue consumed) { - return primitiveConsumers.get(consumed); + List> getPerElementConsumers(PValue consumed) { + return perElementConsumers.get(consumed); + } + + List> getAllConsumers(PValue consumed) { + return allConsumers.get(consumed); } Set> getRootTransforms() { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 07bcf06926cb4..675de2c15650a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -22,10 +22,12 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; @@ -37,6 +39,8 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -44,12 +48,15 @@ * input after the upstream transform has produced and committed output. */ class DirectGraphVisitor extends PipelineVisitor.Defaults { + private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class); private Map, AppliedPTransform> producers = new HashMap<>(); private Map, AppliedPTransform> viewWriters = new HashMap<>(); private Set> consumedViews = new HashSet<>(); - private ListMultimap> primitiveConsumers = + private ListMultimap> perElementConsumers = + ArrayListMultimap.create(); + private ListMultimap> allConsumers = ArrayListMultimap.create(); private Set> rootTransforms = new HashSet<>(); @@ -94,8 +101,19 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { + Collection mainInputs = + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); + if (!mainInputs.containsAll(node.getInputs().values())) { + LOG.debug( + "Inputs reduced to {} from {} by removing additional inputs", + mainInputs, + node.getInputs().values()); + } + for (PValue value : mainInputs) { + perElementConsumers.put(value, appliedTransform); + } for (PValue value : node.getInputs().values()) { - primitiveConsumers.put(value, appliedTransform); + allConsumers.put(value, appliedTransform); } } if (node.getTransform() instanceof ParDo.MultiOutput) { @@ -106,7 +124,7 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { } } - @Override + @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { AppliedPTransform appliedTransform = getAppliedTransform(producer); if (value instanceof PCollection && !producers.containsKey(value)) { @@ -131,6 +149,6 @@ private String genStepName() { public DirectGraph getGraph() { checkState(finalized, "Can't get a graph before the Pipeline has been completely traversed"); return DirectGraph.create( - producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); + 
producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 71ab4cc0e08ca..6fe8ebd2609c7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -355,7 +355,7 @@ public final CommittedResult handleResult( for (CommittedBundle outputBundle : committedResult.getOutputs()) { allUpdates.offer( ExecutorUpdate.fromBundle( - outputBundle, graph.getPrimitiveConsumers(outputBundle.getPCollection()))); + outputBundle, graph.getPerElementConsumers(outputBundle.getPCollection()))); } CommittedBundle unprocessedInputs = committedResult.getUnprocessedInputs(); if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 8aa75cf1445e4..516f798aba975 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -20,7 +20,6 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +78,7 @@ public TransformEvaluator forApplication( (TransformEvaluator) createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, transform.getSideInputs(), @@ -102,6 +102,7 @@ public void cleanup() throws Exception { @SuppressWarnings({"unchecked", "rawtypes"}) DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( AppliedPTransform, PCollectionTuple, ?> application, + PCollection mainInput, StructuralKey inputBundleKey, DoFn doFn, List> sideInputs, @@ -120,6 +121,7 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( createParDoEvaluator( application, inputBundleKey, + mainInput, sideInputs, mainOutputTag, additionalOutputTags, @@ -132,6 +134,7 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( ParDoEvaluator createParDoEvaluator( AppliedPTransform, PCollectionTuple, ?> application, StructuralKey key, + PCollection mainInput, List> sideInputs, TupleTag mainOutputTag, List> additionalOutputTags, @@ -144,8 +147,7 @@ ParDoEvaluator createParDoEvaluator( evaluationContext, stepContext, application, - ((PCollection) Iterables.getOnlyElement(application.getInputs().values())) - .getWindowingStrategy(), + mainInput.getWindowingStrategy(), fn, key, sideInputs, @@ -173,5 +175,4 @@ static Map, PCollection> pcollections(Map, PValue> ou } return pcs; } - } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index b85f481c14890..eccc83a031cb2 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java 
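// A minimal sketch of the behavior that TransformInputsTest (above) pins down, as a reading
// aid for the DirectGraph/ParDo evaluator changes in this patch: every input the composed
// transform does not declare via getAdditionalInputs() is treated as a per-element main input.
// The class name and the exact error message below are illustrative, not the real code.

import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Map;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;

class TransformInputsSketch {
  static Collection<PValue> nonAdditionalInputs(AppliedPTransform<?, ?, ?> application) {
    ImmutableList.Builder<PValue> mainInputs = ImmutableList.builder();
    Map<TupleTag<?>, PValue> additional = application.getTransform().getAdditionalInputs();
    for (Map.Entry<TupleTag<?>, PValue> input : application.getInputs().entrySet()) {
      // Side inputs and other additional inputs share tags with getAdditionalInputs();
      // everything else is consumed element-by-element by the runner.
      if (!additional.containsKey(input.getKey())) {
        mainInputs.add(input.getValue());
      }
    }
    Collection<PValue> result = mainInputs.build();
    // No inputs at all is fine (nonAdditionalInputsWithNoInputSucceeds), but a transform whose
    // inputs are exclusively additional is rejected, matching the "at least one" expectation.
    checkArgument(
        application.getInputs().isEmpty() || !result.isEmpty(),
        "Expected at least one main input for %s",
        application.getFullName());
    return result;
  }
}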
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -116,6 +116,8 @@ public void cleanup() throws Exception { delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), + (PCollection>>) + inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 506c84cec6390..3619d05b47caf 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -117,6 +117,7 @@ private TransformEvaluator>> createEvaluator( DoFnLifecycleManagerRemovingTransformEvaluator> delegateEvaluator = delegateFactory.createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, application.getTransform().getUnderlyingParDo().getSideInputs(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 40ce163012fba..80a3504599d8c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -54,6 +54,7 @@ import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternals.TimerData; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.state.TimeDomain; @@ -62,7 +63,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; /** @@ -831,11 +831,11 @@ private TransformWatermarks getTransformWatermark(AppliedPTransform tra private Collection getInputProcessingWatermarks(AppliedPTransform transform) { ImmutableList.Builder inputWmsBuilder = ImmutableList.builder(); - Map, PValue> inputs = transform.getInputs(); + Collection inputs = TransformInputs.nonAdditionalInputs(transform); if (inputs.isEmpty()) { inputWmsBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs.values()) { + for (PValue pvalue : inputs) { Watermark producerOutputWatermark = getValueWatermark(pvalue).synchronizedProcessingOutputWatermark; inputWmsBuilder.add(producerOutputWatermark); @@ -845,11 +845,11 @@ private Collection getInputProcessingWatermarks(AppliedPTransform getInputWatermarks(AppliedPTransform transform) { ImmutableList.Builder inputWatermarksBuilder = ImmutableList.builder(); - Map, PValue> inputs = transform.getInputs(); + Collection< PValue> inputs = TransformInputs.nonAdditionalInputs(transform); if (inputs.isEmpty()) { inputWatermarksBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs.values()) { + for (PValue pvalue : inputs) { Watermark producerOutputWatermark = getValueWatermark(pvalue).outputWatermark; inputWatermarksBuilder.add(producerOutputWatermark); } @@ -987,7 +987,7 @@ private void 
updatePending( // refresh. for (CommittedBundle bundle : result.getOutputs()) { for (AppliedPTransform consumer : - graph.getPrimitiveConsumers(bundle.getPCollection())) { + graph.getPerElementConsumers(bundle.getPCollection())) { TransformWatermarks watermarks = transformToWatermarks.get(consumer); watermarks.addPending(bundle); } @@ -1035,7 +1035,7 @@ synchronized void refreshAll() { if (updateResult.isAdvanced()) { Set> additionalRefreshes = new HashSet<>(); for (PValue outputPValue : toRefresh.getOutputs().values()) { - additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue)); + additionalRefreshes.addAll(graph.getPerElementConsumers(outputPValue)); } return additionalRefreshes; } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java index 576edf364fae5..bf3e83e88fbe5 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java @@ -151,13 +151,13 @@ public void processElement(DoFn.ProcessContext c) graph.getProducer(flattened); assertThat( - graph.getPrimitiveConsumers(created), + graph.getPerElementConsumers(created), Matchers.>containsInAnyOrder( transformedProducer, flattenedProducer)); assertThat( - graph.getPrimitiveConsumers(transformed), + graph.getPerElementConsumers(transformed), Matchers.>containsInAnyOrder(flattenedProducer)); - assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); + assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); } @Test @@ -173,10 +173,10 @@ public void getValueToConsumersWithDuplicateInputSucceeds() { AppliedPTransform flattenedProducer = graph.getProducer(flattened); assertThat( - graph.getPrimitiveConsumers(created), + graph.getPerElementConsumers(created), Matchers.>containsInAnyOrder(flattenedProducer, flattenedProducer)); - assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); + assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); } @Test diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java index f3edf552b27e8..699a31870d65e 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java @@ -414,7 +414,7 @@ public void isDoneWithPartiallyDone() { StepTransformResult.withoutHold(unboundedProducer).build()); assertThat(context.isDone(), is(false)); - for (AppliedPTransform consumers : graph.getPrimitiveConsumers(created)) { + for (AppliedPTransform consumers : graph.getPerElementConsumers(created)) { context.handleResult( committedBundle, ImmutableList.of(), diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index df84cbf6f2c9a..7912538d58091 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -98,7 +98,7 @@ public void sideInputsNotReadyResultHasUnprocessedElements() { 
when(evaluationContext.createBundle(output)).thenReturn(outputBundle); ParDoEvaluator evaluator = - createEvaluator(singletonView, fn, output); + createEvaluator(singletonView, fn, inputPc, output); IntervalWindow nonGlobalWindow = new IntervalWindow(new Instant(0), new Instant(10_000L)); WindowedValue first = WindowedValue.valueInGlobalWindow(3); @@ -132,6 +132,7 @@ public void sideInputsNotReadyResultHasUnprocessedElements() { private ParDoEvaluator createEvaluator( PCollectionView singletonView, RecorderFn fn, + PCollection input, PCollection output) { when( evaluationContext.createSideInputReader( @@ -157,8 +158,7 @@ private ParDoEvaluator createEvaluator( evaluationContext, stepContext, transform, - ((PCollection) Iterables.getOnlyElement(transform.getInputs().values())) - .getWindowingStrategy(), + input.getWindowingStrategy(), fn, null /* key */, ImmutableList.>of(singletonView), diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java index 0439119dfc40b..6e7019848b194 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -20,6 +20,7 @@ import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -143,7 +144,7 @@ Map, PValue> getInputs(PTransform transform) { @SuppressWarnings("unchecked") T getInput(PTransform transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } Map, PValue> getOutputs(PTransform transform) { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index ea5f6b3162af2..74a5fb971144a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -113,7 +114,7 @@ public TypeInformation> getTypeInfo(PCollection collecti @SuppressWarnings("unchecked") public T getInput(PTransform transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } public Map, PValue> getInputs(PTransform transform) { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index a3a7ab6bb1611..afc34e6fc8833 100644 --- 
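// Why getInput() in the Flink translation contexts (and, below, the Dataflow and Spark ones)
// stops calling getOnlyElement over getInputs(): once TransformHierarchy.Node records additional
// inputs alongside the main input (see the TransformHierarchy change later in this patch), a
// ParDo with side inputs has several entries in getInputs() and getOnlyElement would throw.
// Hedged sketch, with currentTransform standing in for the context's current AppliedPTransform:
//
//   // old: assumes the transform consumes exactly one PValue
//   PValue in = Iterables.getOnlyElement(currentTransform.getInputs().values());
//
//   // new: additional (side) inputs are filtered out first, so a single main input remains
//   PValue in = Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform));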
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -56,6 +56,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.core.construction.WindowingStrategyTranslation; import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly; import org.apache.beam.runners.dataflow.DataflowRunner.CombineGroupedValues; @@ -395,7 +396,9 @@ public Map, PValue> getInputs( @Override public InputT getInput(PTransform transform) { - return (InputT) Iterables.getOnlyElement(getInputs(transform).values()); + return (InputT) + Iterables.getOnlyElement( + TransformInputs.nonAdditionalInputs(getCurrentTransform(transform))); } @Override diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 8102926f6daad..0c6c4d1cb6607 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.Pipeline; @@ -103,7 +104,8 @@ public void setCurrentTransform(AppliedPTransform transform) { public T getInput(PTransform transform) { @SuppressWarnings("unchecked") - T input = (T) Iterables.getOnlyElement(getInputs(transform).values()); + T input = + (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); return input; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 5e048ebb08197..9c5f14843c240 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -34,7 +34,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior; @@ -71,7 +70,7 @@ public TransformHierarchy() { producers = new HashMap<>(); producerInput = new HashMap<>(); unexpandedInputs = new HashMap<>(); - root = new Node(null, null, "", null); + root = new Node(); current = root; } @@ -296,26 +295,37 @@ public class Node { @VisibleForTesting boolean finishedSpecifying = false; + /** + * Creates the root-level node. The root level node has a null enclosing node, a null transform, + * an empty map of inputs, and a name equal to the empty string. + */ + private Node() { + this.enclosingNode = null; + this.transform = null; + this.fullName = ""; + this.inputs = Collections.emptyMap(); + } + /** * Creates a new Node with the given parent and transform. * - *
EnclosingNode and transform may both be null for a root-level node, which holds all other - * nodes. - * * @param enclosingNode the composite node containing this node * @param transform the PTransform tracked by this node * @param fullName the fully qualified name of the transform * @param input the unexpanded input to the transform */ private Node( - @Nullable Node enclosingNode, - @Nullable PTransform transform, + Node enclosingNode, + PTransform transform, String fullName, - @Nullable PInput input) { + PInput input) { this.enclosingNode = enclosingNode; this.transform = transform; this.fullName = fullName; - this.inputs = input == null ? Collections., PValue>emptyMap() : input.expand(); + ImmutableMap.Builder, PValue> inputs = ImmutableMap.builder(); + inputs.putAll(input.expand()); + inputs.putAll(transform.getAdditionalInputs()); + this.inputs = inputs.build(); } /** From 8c5b57ea8445cd50a35c6dffb460dcf0f426e700 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 26 May 2017 14:26:55 -0700 Subject: [PATCH 019/200] Port ViewOverrideFactory to SDK-agnostic APIs --- .../CreatePCollectionViewTranslation.java | 4 +- .../runners/direct/ViewOverrideFactory.java | 48 +++++++++++-------- .../direct/ViewEvaluatorFactoryTest.java | 3 +- .../direct/ViewOverrideFactoryTest.java | 23 +++++++-- .../beam/sdk/values/PCollectionViews.java | 10 ++++ 5 files changed, 62 insertions(+), 26 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java index aa24909c2bf6c..8fc99b9f480f2 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslation.java @@ -56,8 +56,8 @@ public class CreatePCollectionViewTranslation { @Deprecated public static PCollectionView getView( AppliedPTransform< - PCollection, PCollectionView, - PTransform, PCollectionView>> + PCollection, PCollection, + PTransform, PCollection>> application) throws IOException { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java index 06a73889a1aeb..5dcf0165c0f07 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java @@ -18,8 +18,9 @@ package org.apache.beam.runners.direct; +import java.io.IOException; import java.util.Map; -import org.apache.beam.runners.core.construction.ForwardingPTransform; +import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; import org.apache.beam.runners.core.construction.ReplacementOutputs; @@ -43,16 +44,30 @@ */ class ViewOverrideFactory implements PTransformOverrideFactory< - PCollection, PCollection, CreatePCollectionView> { + PCollection, PCollection, + PTransform, PCollection>> { @Override public PTransformReplacement, PCollection> getReplacementTransform( AppliedPTransform< - PCollection, PCollection, CreatePCollectionView> + 
PCollection, PCollection, + PTransform, PCollection>> transform) { - return PTransformReplacement.of( + + PCollectionView view; + try { + view = CreatePCollectionViewTranslation.getView(transform); + } catch (IOException exc) { + throw new RuntimeException( + String.format( + "Could not extract %s from transform %s", + PCollectionView.class.getSimpleName(), transform), + exc); + } + + return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new GroupAndWriteView<>(transform.getTransform())); + new GroupAndWriteView(view)); } @Override @@ -63,11 +78,11 @@ public Map mapOutputs( /** The {@link DirectRunner} composite override for {@link CreatePCollectionView}. */ static class GroupAndWriteView - extends ForwardingPTransform, PCollection> { - private final CreatePCollectionView og; + extends PTransform, PCollection> { + private final PCollectionView view; - private GroupAndWriteView(CreatePCollectionView og) { - this.og = og; + private GroupAndWriteView(PCollectionView view) { + this.view = view; } @Override @@ -77,14 +92,9 @@ public PCollection expand(final PCollection input) { .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())) .apply(GroupByKey.create()) .apply(Values.>create()) - .apply(new WriteView(og)); + .apply(new WriteView(view)); return input; } - - @Override - protected PTransform, PCollection> delegate() { - return og; - } } /** @@ -96,10 +106,10 @@ protected PTransform, PCollection> delegate() { */ static final class WriteView extends RawPTransform>, PCollection>> { - private final CreatePCollectionView og; + private final PCollectionView view; - WriteView(CreatePCollectionView og) { - this.og = og; + WriteView(PCollectionView view) { + this.view = view; } @Override @@ -112,7 +122,7 @@ public PCollection> expand(PCollection> input) { @SuppressWarnings("deprecation") public PCollectionView getView() { - return og.getView(); + return view; } @Override diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java index ad1aecce623b1..5bc48b7104c02 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java @@ -66,7 +66,8 @@ public void testInMemoryEvaluator() throws Exception { .apply(GroupByKey.create()) .apply(Values.>create()); PCollection> view = - concat.apply(new ViewOverrideFactory.WriteView<>(createView)); + concat.apply( + new ViewOverrideFactory.WriteView>(createView.getView())); EvaluationContext context = mock(EvaluationContext.class); TestViewWriter> viewWriter = new TestViewWriter<>(); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java index 94728c7909209..6af9273ef021b 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.direct; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; @@ -36,8 +37,11 @@ import org.apache.beam.sdk.testing.TestPipeline; import 
org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; +import org.apache.beam.sdk.transforms.ViewFn; +import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; @@ -67,7 +71,7 @@ public void replacementSucceeds() { factory.getReplacementTransform( AppliedPTransform ., PCollection, - CreatePCollectionView>> + PTransform, PCollection>> of( "foo", ints.expand(), @@ -102,7 +106,7 @@ public void replacementGetViewReturnsOriginal() { factory.getReplacementTransform( AppliedPTransform ., PCollection, - CreatePCollectionView>> + PTransform, PCollection>> of( "foo", ints.expand(), @@ -120,8 +124,19 @@ public void visitPrimitiveTransform(Node node) { "There should only be one WriteView primitive in the graph", writeViewVisited.getAndSet(true), is(false)); - PCollectionView replacementView = ((WriteView) node.getTransform()).getView(); - assertThat(replacementView, Matchers.theInstance(view)); + PCollectionView replacementView = ((WriteView) node.getTransform()).getView(); + + // replacementView.getPCollection() is null, but that is not a requirement + // so not asserted one way or the other + assertThat( + replacementView.getTagInternal(), + equalTo(view.getTagInternal())); + assertThat( + replacementView.getViewFn(), + Matchers.>equalTo(view.getViewFn())); + assertThat( + replacementView.getWindowMappingFn(), + Matchers.>equalTo(view.getWindowMappingFn())); assertThat(node.getInputs().entrySet(), hasSize(1)); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java index 5e2e2c316e2e1..0c04370a11c51 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java @@ -282,6 +282,16 @@ public T apply(WindowedValue input) { } })); } + + @Override + public boolean equals(Object other) { + return other instanceof ListViewFn; + } + + @Override + public int hashCode() { + return ListViewFn.class.hashCode(); + } } /** From 02dbaefd2bbad0f0ff0b87469d184137b220fae7 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 26 May 2017 14:27:23 -0700 Subject: [PATCH 020/200] Port DirectGroupByKey to SDK-agnostic APIs --- .../beam/runners/direct/DirectGroupByKey.java | 13 +++++++------ .../direct/DirectGroupByKeyOverrideFactory.java | 14 +++++++++++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java index 2fc0dd4ae409a..06b8e29962be4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java @@ -36,13 +36,17 @@ class DirectGroupByKey extends ForwardingPTransform>, PCollection>>> { - private final GroupByKey original; + private final PTransform>, PCollection>>> original; static final String DIRECT_GBKO_URN = "urn:beam:directrunner:transforms:gbko:v1"; static final String DIRECT_GABW_URN = 
"urn:beam:directrunner:transforms:gabw:v1"; + private final WindowingStrategy outputWindowingStrategy; - DirectGroupByKey(GroupByKey from) { - this.original = from; + DirectGroupByKey( + PTransform>, PCollection>>> original, + WindowingStrategy outputWindowingStrategy) { + this.original = original; + this.outputWindowingStrategy = outputWindowingStrategy; } @Override @@ -57,9 +61,6 @@ public PCollection>> expand(PCollection> input) { // key/value input elements and the window merge operation of the // window function associated with the input PCollection. WindowingStrategy inputWindowingStrategy = input.getWindowingStrategy(); - // Update the windowing strategy as appropriate. - WindowingStrategy outputWindowingStrategy = - original.updateWindowingStrategy(inputWindowingStrategy); // By default, implement GroupByKey via a series of lower-level operations. return input diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java index c2eb5e72b8425..9c2de3dc33794 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java @@ -17,26 +17,34 @@ */ package org.apache.beam.runners.direct; +import com.google.common.collect.Iterables; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; /** A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. 
*/ final class DirectGroupByKeyOverrideFactory extends SingleInputOutputOverrideFactory< - PCollection>, PCollection>>, GroupByKey> { + PCollection>, PCollection>>, + PTransform>, PCollection>>>> { @Override public PTransformReplacement>, PCollection>>> getReplacementTransform( AppliedPTransform< - PCollection>, PCollection>>, GroupByKey> + PCollection>, PCollection>>, + PTransform>, PCollection>>>> transform) { + + PCollection>> output = + (PCollection>>) Iterables.getOnlyElement(transform.getOutputs().values()); + return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new DirectGroupByKey<>(transform.getTransform())); + new DirectGroupByKey<>(transform.getTransform(), output.getWindowingStrategy())); } } From ed6bd18bffe8a51d5fc2a59ff9aaa731b196d58a Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 26 May 2017 16:07:45 -0700 Subject: [PATCH 021/200] Port DirectRunner WriteFiles override to SDK-agnostic APIs --- .../core/construction/PTransformMatchers.java | 17 ++++++++--- .../direct/WriteWithShardingFactory.java | 30 +++++++++++++------ .../direct/WriteWithShardingFactoryTest.java | 26 +++++++++++----- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java index c339891d51eda..0d272411f079c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java @@ -17,13 +17,14 @@ */ package org.apache.beam.runners.core.construction; +import static org.apache.beam.runners.core.construction.PTransformTranslation.WRITE_FILES_TRANSFORM_URN; + import com.google.common.base.MoreObjects; import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; -import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformMatcher; import org.apache.beam.sdk.transforms.DoFn; @@ -359,10 +360,18 @@ public static PTransformMatcher writeWithRunnerDeterminedSharding() { return new PTransformMatcher() { @Override public boolean matches(AppliedPTransform application) { - if (PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals( + if (WRITE_FILES_TRANSFORM_URN.equals( PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { - WriteFiles write = (WriteFiles) application.getTransform(); - return write.getSharding() == null && write.getNumShards() == null; + try { + return WriteFilesTranslation.isRunnerDeterminedSharding( + (AppliedPTransform) application); + } catch (IOException exc) { + throw new RuntimeException( + String.format( + "Transform with URN %s failed to parse: %s", + WRITE_FILES_TRANSFORM_URN, application.getTransform()), + exc); + } } return false; } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index 65a5a19382c29..d8734a1c55544 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -21,11 +21,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; +import java.io.IOException; import java.io.Serializable; import java.util.Collections; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.runners.core.construction.PTransformReplacements; +import org.apache.beam.runners.core.construction.WriteFilesTranslation; import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; @@ -43,23 +45,33 @@ import org.apache.beam.sdk.values.TupleTag; /** - * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} - * {@link PTransform PTransforms} with an unspecified number of shards with a write with a - * specified number of shards. The number of shards is the log base 10 of the number of input - * records, with up to 2 additional shards. + * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} {@link PTransform + * PTransforms} with an unspecified number of shards with a write with a specified number of shards. + * The number of shards is the log base 10 of the number of input records, with up to 2 additional + * shards. */ class WriteWithShardingFactory - implements PTransformOverrideFactory, PDone, WriteFiles> { + implements PTransformOverrideFactory< + PCollection, PDone, PTransform, PDone>> { static final int MAX_RANDOM_EXTRA_SHARDS = 3; @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; @Override public PTransformReplacement, PDone> getReplacementTransform( - AppliedPTransform, PDone, WriteFiles> transform) { + AppliedPTransform, PDone, PTransform, PDone>> + transform) { - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(transform), - transform.getTransform().withSharding(new LogElementShardsWithDrift())); + try { + WriteFiles replacement = WriteFiles.to(WriteFilesTranslation.getSink(transform)); + if (WriteFilesTranslation.isWindowedWrites(transform)) { + replacement = replacement.withWindowedWrites(); + } + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + replacement.withSharding(new LogElementShardsWithDrift())); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java index a88d95e5e5c8d..41d671f5c8e17 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java @@ -30,6 +30,7 @@ import java.io.File; import java.io.FileReader; import java.io.Reader; +import java.io.Serializable; import java.nio.CharBuffer; import java.util.ArrayList; import java.util.Collections; @@ -53,6 +54,7 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -71,11 +73,17 @@ * Tests for {@link 
WriteWithShardingFactory}. */ @RunWith(JUnit4.class) -public class WriteWithShardingFactoryTest { +public class WriteWithShardingFactoryTest implements Serializable { + private static final int INPUT_SIZE = 10000; - @Rule public TemporaryFolder tmp = new TemporaryFolder(); - private WriteWithShardingFactory factory = new WriteWithShardingFactory<>(); - @Rule public final TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + + @Rule public transient TemporaryFolder tmp = new TemporaryFolder(); + + private transient WriteWithShardingFactory factory = new WriteWithShardingFactory<>(); + + @Rule + public final transient TestPipeline p = + TestPipeline.create().enableAbandonedNodeEnforcement(false); @Test public void dynamicallyReshardedWrite() throws Exception { @@ -135,7 +143,8 @@ public void withNoShardingSpecifiedReturnsNewTransform() { DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE, "", false); - WriteFiles original = + + PTransform, PDone> original = WriteFiles.to( new FileBasedSink(StaticValueProvider.of(outputDirectory), policy) { @Override @@ -146,9 +155,10 @@ public WriteOperation createWriteOperation() { @SuppressWarnings("unchecked") PCollection objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); - AppliedPTransform, PDone, WriteFiles> originalApplication = - AppliedPTransform.of( - "write", objs.expand(), Collections., PValue>emptyMap(), original, p); + AppliedPTransform, PDone, PTransform, PDone>> + originalApplication = + AppliedPTransform.of( + "write", objs.expand(), Collections., PValue>emptyMap(), original, p); assertThat( factory.getReplacementTransform(originalApplication).getTransform(), From eaaf45fa33d500a9f0fd0c2861aac4889ee5086c Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 8 Jun 2017 13:39:32 -0700 Subject: [PATCH 022/200] Port DirectRunner TestStream override to SDK-agnostic APIs --- .../construction/TestStreamTranslation.java | 49 ++++++++++++++++++- .../direct/TestStreamEvaluatorFactory.java | 20 ++++++-- .../apache/beam/sdk/testing/TestStream.java | 12 +++++ 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java index 90e63047cab17..515de575ecb1f 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TestStreamTranslation.java @@ -18,6 +18,9 @@ package org.apache.beam.runners.core.construction; +import static com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.runners.core.construction.PTransformTranslation.TEST_STREAM_TRANSFORM_URN; + import com.google.auto.service.AutoService; import com.google.protobuf.Any; import com.google.protobuf.ByteString; @@ -33,6 +36,8 @@ import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; @@ -57,6 +62,48 @@ static RunnerApi.TestStreamPayload testStreamToPayload( return builder.build(); } + private static TestStream fromProto( + RunnerApi.TestStreamPayload 
testStreamPayload, RunnerApi.Components components) + throws IOException { + + Coder coder = + (Coder) + CoderTranslation.fromProto( + components.getCodersOrThrow(testStreamPayload.getCoderId()), components); + + List> events = new ArrayList<>(); + + for (RunnerApi.TestStreamPayload.Event event : testStreamPayload.getEventsList()) { + events.add(fromProto(event, coder)); + } + return TestStream.fromRawEvents(coder, events); + } + + /** + * Converts an {@link AppliedPTransform}, which may be a rehydrated transform or an original + * {@link TestStream}, to a {@link TestStream}. + */ + public static TestStream getTestStream( + AppliedPTransform, PTransform>> application) + throws IOException { + // For robustness, we don't take this shortcut: + // if (application.getTransform() instanceof TestStream) { + // return application.getTransform() + // } + + SdkComponents sdkComponents = SdkComponents.create(); + RunnerApi.PTransform transformProto = PTransformTranslation.toProto(application, sdkComponents); + checkArgument( + TEST_STREAM_TRANSFORM_URN.equals(transformProto.getSpec().getUrn()), + "Attempt to get %s from a transform with wrong URN %s", + TestStream.class.getSimpleName(), + transformProto.getSpec().getUrn()); + RunnerApi.TestStreamPayload testStreamPayload = + transformProto.getSpec().getParameter().unpack(RunnerApi.TestStreamPayload.class); + + return (TestStream) fromProto(testStreamPayload, sdkComponents.toComponents()); + } + static RunnerApi.TestStreamPayload.Event toProto(TestStream.Event event, Coder coder) throws IOException { switch (event.getType()) { @@ -130,7 +177,7 @@ static TestStream.Event fromProto( static class TestStreamTranslator implements TransformPayloadTranslator> { @Override public String getUrn(TestStream transform) { - return PTransformTranslation.TEST_STREAM_TRANSFORM_URN; + return TEST_STREAM_TRANSFORM_URN; } @Override diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java index 2da7a71c3d425..16c8589c5f749 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.collect.Iterables; +import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -30,6 +31,7 @@ import javax.annotation.Nullable; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; +import org.apache.beam.runners.core.construction.TestStreamTranslation; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.testing.TestStream; @@ -160,7 +162,8 @@ public Clock get() { } static class DirectTestStreamFactory - implements PTransformOverrideFactory, TestStream> { + implements PTransformOverrideFactory< + PBegin, PCollection, PTransform>> { private final DirectRunner runner; DirectTestStreamFactory(DirectRunner runner) { @@ -169,10 +172,17 @@ static class DirectTestStreamFactory @Override public PTransformReplacement> getReplacementTransform( - AppliedPTransform, TestStream> transform) { - return PTransformReplacement.of( - 
transform.getPipeline().begin(), - new DirectTestStream(runner, transform.getTransform())); + AppliedPTransform, PTransform>> transform) { + try { + return PTransformReplacement.of( + transform.getPipeline().begin(), + new DirectTestStream(runner, TestStreamTranslation.getTestStream(transform))); + } catch (IOException exc) { + throw new RuntimeException( + String.format( + "Transform could not be converted to %s", TestStream.class.getSimpleName()), + exc); + } } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java index 9ad8fd8ea6493..d13fcf1e86b7f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java @@ -271,6 +271,18 @@ public List> getEvents() { return events; } + /** + * For internal use only. No backwards-compatibility guarantees. + * + *
Builder a test stream directly from events. No validation is performed on + * watermark monotonicity, etc. This is assumed to be a previously-serialized + * {@link TestStream} transform that is correct by construction. + */ + @Internal + public static TestStream fromRawEvents(Coder coder, List> events) { + return new TestStream<>(coder, events); + } + @Override public boolean equals(Object other) { if (!(other instanceof TestStream)) { From 8362bdb9cd35cc02ed179b3a64fd72f1264a99be Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Thu, 8 Jun 2017 01:31:34 +0800 Subject: [PATCH 023/200] [BEAM-2423] Abstract StateInternalsTest for the different state internals --- pom.xml | 7 + .../core/InMemoryStateInternalsTest.java | 555 ++--------------- .../beam/runners/core/StateInternalsTest.java | 573 ++++++++++++++++++ runners/flink/pom.xml | 8 + .../streaming/FlinkStateInternalsTest.java | 348 +---------- 5 files changed, 641 insertions(+), 850 deletions(-) create mode 100644 runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java diff --git a/pom.xml b/pom.xml index 805a8d64e9c86..9373a40b19dd5 100644 --- a/pom.xml +++ b/pom.xml @@ -509,6 +509,13 @@ ${project.version} + + org.apache.beam + beam-runners-core-java + ${project.version} + test-jar + + org.apache.beam beam-runners-direct-java diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java index b526305cea00b..335c2f853c97a 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java @@ -17,545 +17,58 @@ */ package org.apache.beam.runners.core; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasItems; -import static org.hamcrest.Matchers.not; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import java.util.Arrays; -import java.util.Map; -import java.util.Objects; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.GroupingState; -import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.SetState; -import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.state.WatermarkHoldState; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.state.State; import org.hamcrest.Matchers; -import org.joda.time.Instant; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.junit.runners.Suite; /** - * Tests for {@link InMemoryStateInternals}. + * Tests for {@link InMemoryStateInternals}. This is based on {@link StateInternalsTest}. 
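 *
 * <p>The shared suite is intended to be reused by other {@code StateInternals} implementations
 * as well; a sketch of the shape, with type parameters omitted and a hypothetical runner-specific
 * factory (the FlinkStateInternalsTest change in this commit follows the same pattern):
 *
 * <pre>{@code
 * @RunWith(JUnit4.class)
 * public static class StandardStateInternalsTests extends StateInternalsTest {
 *   @Override
 *   protected StateInternals createStateInternals() {
 *     return MyRunnerStateInternals.forKey("dummyKey"); // assumed factory method
 *   }
 * }
 * }</pre>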
*/ -@RunWith(JUnit4.class) +@RunWith(Suite.class) +@Suite.SuiteClasses({ + InMemoryStateInternalsTest.StandardStateInternalsTests.class, + InMemoryStateInternalsTest.OtherTests.class +}) public class InMemoryStateInternalsTest { - private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - private static final StateTag> STRING_VALUE_ADDR = - StateTags.value("stringValue", StringUtf8Coder.of()); - private static final StateTag> - SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( - "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - private static final StateTag> STRING_SET_ADDR = - StateTags.set("stringSet", StringUtf8Coder.of()); - private static final StateTag> STRING_MAP_ADDR = - StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of()); - private static final StateTag WATERMARK_EARLIEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); - private static final StateTag WATERMARK_LATEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); - private static final StateTag WATERMARK_EOW_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); - - InMemoryStateInternals underTest = InMemoryStateInternals.forKey("dummyKey"); - - @Test - public void testValue() throws Exception { - ValueState value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); - - // State instances are cached, but depend on the namespace. - assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), Matchers.sameInstance(value)); - assertThat( - underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), - Matchers.not(Matchers.sameInstance(value))); - - assertThat(value.read(), Matchers.nullValue()); - value.write("hello"); - assertThat(value.read(), equalTo("hello")); - value.write("world"); - assertThat(value.read(), equalTo("world")); - - value.clear(); - assertThat(value.read(), Matchers.nullValue()); - assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_BAG_ADDR))); - assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)))); - - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), containsInAnyOrder("hello")); - - value.add("world"); - assertThat(value.read(), containsInAnyOrder("hello", "world")); - - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertThat(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeBagIntoSource() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); - - // Reading the merged bag gets both the contents - assertThat(bag1.read(), containsInAnyOrder("Hello", "World", "!")); - assertThat(bag2.read(), Matchers.emptyIterable()); - } - - @Test - public void testMergeBagIntoNewNamespace() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); - - // Reading the merged bag gets both the contents - assertThat(bag3.read(), containsInAnyOrder("Hello", "World", "!")); - assertThat(bag1.read(), Matchers.emptyIterable()); - assertThat(bag2.read(), Matchers.emptyIterable()); - } - - @Test - public void testSet() throws Exception { - SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_SET_ADDR))); - assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_SET_ADDR)))); - - // empty - assertThat(value.read(), Matchers.emptyIterable()); - assertFalse(value.contains("A").read()); - - // add - value.add("A"); - value.add("B"); - value.add("A"); - assertFalse(value.addIfAbsent("B").read()); - assertThat(value.read(), containsInAnyOrder("A", "B")); - - // remove - value.remove("A"); - assertThat(value.read(), containsInAnyOrder("B")); - value.remove("C"); - assertThat(value.read(), containsInAnyOrder("B")); - - // contains - assertFalse(value.contains("A").read()); - assertTrue(value.contains("B").read()); - value.add("C"); - value.add("D"); - - // readLater - assertThat(value.readLater().read(), containsInAnyOrder("B", "C", "D")); - SetState later = value.readLater(); - assertThat(later.read(), hasItems("C", "D")); - assertFalse(later.contains("A").read()); - - // clear - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertThat(underTest.state(NAMESPACE_1, STRING_SET_ADDR), Matchers.sameInstance(value)); - - } - - @Test - public void testSetIsEmpty() throws Exception { - SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeSetIntoSource() throws Exception { - SetState set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); - SetState set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); - - set1.add("Hello"); - set2.add("Hello"); - set2.add("World"); - set1.add("!"); - - StateMerging.mergeSets(Arrays.asList(set1, set2), set1); - - // Reading the merged set gets both the contents - assertThat(set1.read(), containsInAnyOrder("Hello", "World", "!")); - assertThat(set2.read(), Matchers.emptyIterable()); - } - - @Test - public void testMergeSetIntoNewNamespace() throws Exception { - SetState set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); - SetState set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); - SetState set3 = underTest.state(NAMESPACE_3, STRING_SET_ADDR); - - set1.add("Hello"); - set2.add("Hello"); - set2.add("World"); - set1.add("!"); - - StateMerging.mergeSets(Arrays.asList(set1, set2, set3), set3); - - // Reading the merged set gets both the contents - assertThat(set3.read(), containsInAnyOrder("Hello", "World", "!")); - assertThat(set1.read(), Matchers.emptyIterable()); - assertThat(set2.read(), Matchers.emptyIterable()); - } - - // for testMap - private static class MapEntry implements Map.Entry { - private K key; - private V value; - - private MapEntry(K key, V value) { - this.key = key; - this.value = value; - } - - static Map.Entry of(K k, V v) { - return new MapEntry<>(k, v); + /** + * A standard StateInternals test. + */ + @RunWith(JUnit4.class) + public static class StandardStateInternalsTests extends StateInternalsTest { + @Override + protected StateInternals createStateInternals() { + return new InMemoryStateInternals<>("dummyKey"); } + } - public final K getKey() { - return key; - } - public final V getValue() { - return value; - } + /** + * A specific test of InMemoryStateInternals. 
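 *
 * <p>The checks here are stricter than the {@code equalTo} assertions in the shared suite: the
 * in-memory implementation hands back the identical cached object for a given namespace and tag.
 * A minimal illustrative sketch of the distinction (not part of this patch):
 *
 * <pre>{@code
 * ValueState first = underTest.state(StateInternalsTest.NAMESPACE_1, STRING_VALUE_ADDR);
 * ValueState second = underTest.state(StateInternalsTest.NAMESPACE_1, STRING_VALUE_ADDR);
 * assertThat(second, Matchers.sameInstance(first)); // identity, not merely equality
 * }</pre>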
+ */ + @RunWith(JUnit4.class) + public static class OtherTests { - public final String toString() { - return key + "=" + value; - } + StateInternals underTest = new InMemoryStateInternals<>("dummyKey"); - public final int hashCode() { - return Objects.hashCode(key) ^ Objects.hashCode(value); + @Test + public void testSameInstance() { + assertSameInstance(StateInternalsTest.STRING_VALUE_ADDR); + assertSameInstance(StateInternalsTest.SUM_INTEGER_ADDR); + assertSameInstance(StateInternalsTest.STRING_BAG_ADDR); + assertSameInstance(StateInternalsTest.STRING_SET_ADDR); + assertSameInstance(StateInternalsTest.STRING_MAP_ADDR); + assertSameInstance(StateInternalsTest.WATERMARK_EARLIEST_ADDR); } - public final V setValue(V newValue) { - V oldValue = value; - value = newValue; - return oldValue; + private void assertSameInstance(StateTag address) { + assertThat(underTest.state(StateInternalsTest.NAMESPACE_1, address), + Matchers.sameInstance(underTest.state(StateInternalsTest.NAMESPACE_1, address))); } - - public final boolean equals(Object o) { - if (o == this) { - return true; - } - if (o instanceof Map.Entry) { - Map.Entry e = (Map.Entry) o; - if (Objects.equals(key, e.getKey()) - && Objects.equals(value, e.getValue())) { - return true; - } - } - return false; - } - } - - @Test - public void testMap() throws Exception { - MapState value = underTest.state(NAMESPACE_1, STRING_MAP_ADDR); - - // State instances are cached, but depend on the namespace. - assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_MAP_ADDR))); - assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_MAP_ADDR)))); - - // put - assertThat(value.entries().read(), Matchers.emptyIterable()); - value.put("A", 1); - value.put("B", 2); - value.put("A", 11); - assertThat(value.putIfAbsent("B", 22).read(), equalTo(2)); - assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("A", 11), - MapEntry.of("B", 2))); - - // remove - value.remove("A"); - assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("B", 2))); - value.remove("C"); - assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("B", 2))); - - // get - assertNull(value.get("A").read()); - assertThat(value.get("B").read(), equalTo(2)); - value.put("C", 3); - value.put("D", 4); - assertThat(value.get("C").read(), equalTo(3)); - - // iterate - value.put("E", 5); - value.remove("C"); - assertThat(value.keys().read(), containsInAnyOrder("B", "D", "E")); - assertThat(value.values().read(), containsInAnyOrder(2, 4, 5)); - assertThat( - value.entries().read(), - containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); - - // readLater - assertThat(value.get("B").readLater().read(), equalTo(2)); - assertNull(value.get("A").readLater().read()); - assertThat( - value.entries().readLater().read(), - containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); - - // clear - value.clear(); - assertThat(value.entries().read(), Matchers.emptyIterable()); - assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), Matchers.sameInstance(value)); } - @Test - public void testCombiningValue() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); - - assertThat(value.read(), equalTo(0)); - value.add(2); - assertThat(value.read(), equalTo(2)); - - value.add(3); - assertThat(value.read(), equalTo(5)); - - value.clear(); - assertThat(value.read(), equalTo(0)); - assertThat(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testCombiningIsEmpty() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(5); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeCombiningValueIntoSource() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - assertThat(value1.read(), equalTo(11)); - assertThat(value2.read(), equalTo(10)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); - - assertThat(value1.read(), equalTo(21)); - assertThat(value2.read(), equalTo(0)); - } - - @Test - public void testMergeCombiningValueIntoNewNamespace() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - CombiningState value3 = - underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); - - // Merging clears the old values and updates the result value. - assertThat(value1.read(), equalTo(0)); - assertThat(value2.read(), equalTo(0)); - assertThat(value3.read(), equalTo(21)); - } - - @Test - public void testWatermarkEarliestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), equalTo(new Instant(2000))); - - value.add(new Instant(3000)); - assertThat(value.read(), equalTo(new Instant(2000))); - - value.add(new Instant(1000)); - assertThat(value.read(), equalTo(new Instant(1000))); - - value.clear(); - assertThat(value.read(), equalTo(null)); - assertThat(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testWatermarkLatestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), equalTo(new Instant(2000))); - - value.add(new Instant(3000)); - assertThat(value.read(), equalTo(new Instant(3000))); - - value.add(new Instant(1000)); - assertThat(value.read(), equalTo(new Instant(3000))); - - value.clear(); - assertThat(value.read(), equalTo(null)); - assertThat(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testWatermarkEndOfWindowState() throws Exception { - WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), equalTo(new Instant(2000))); - - value.clear(); - assertThat(value.read(), equalTo(null)); - assertThat(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), Matchers.sameInstance(value)); - } - - @Test - public void testWatermarkStateIsEmpty() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(new Instant(1000)); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeEarliestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the merged value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); - - assertThat(value1.read(), equalTo(new Instant(2000))); - assertThat(value2.read(), equalTo(null)); - } - - @Test - public void testMergeLatestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); - WatermarkHoldState value3 = - underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); - - // Merging clears the old values and updates the result value. 
- assertThat(value3.read(), equalTo(new Instant(5000))); - assertThat(value1.read(), equalTo(null)); - assertThat(value2.read(), equalTo(null)); - } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java new file mode 100644 index 0000000000000..bf3156aad110e --- /dev/null +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.SetState; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; + +/** + * Tests for {@link StateInternals}. 
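 *
 * <p>Runner-specific suites reuse these cases by extending this class and supplying their own
 * {@code StateInternals}. A minimal sketch of that pattern with a hypothetical runner name (only
 * the {@code extends}/{@code createStateInternals()} shape is taken from this patch):
 *
 * <pre>{@code
 * @RunWith(JUnit4.class)
 * public class MyRunnerStateInternalsTest extends StateInternalsTest {
 *   @Override
 *   protected StateInternals createStateInternals() {
 *     // Hypothetical factory; a real runner wires in its own backend here, as
 *     // FlinkStateInternalsTest does with Flink's MemoryStateBackend.
 *     return MyRunnerStateInternals.forKey("dummyKey");
 *   }
 * }
 * }</pre>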
+ */ +public abstract class StateInternalsTest { + + private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); + static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); + + static final StateTag> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + static final StateTag> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); + static final StateTag> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + static final StateTag> STRING_SET_ADDR = + StateTags.set("stringSet", StringUtf8Coder.of()); + static final StateTag> STRING_MAP_ADDR = + StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of()); + static final StateTag WATERMARK_EARLIEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); + private static final StateTag WATERMARK_LATEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); + private static final StateTag WATERMARK_EOW_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); + + private StateInternals underTest; + + @Before + public void setUp() { + this.underTest = createStateInternals(); + } + + protected abstract StateInternals createStateInternals(); + + @Test + public void testValue() throws Exception { + ValueState value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); + + // State instances are cached, but depend on the namespace. + assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), equalTo(value)); + assertThat( + underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), + Matchers.not(equalTo(value))); + + assertThat(value.read(), Matchers.nullValue()); + value.write("hello"); + assertThat(value.read(), equalTo("hello")); + value.write("world"); + assertThat(value.read(), equalTo("world")); + + value.clear(); + assertThat(value.read(), Matchers.nullValue()); + assertThat(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), equalTo(value)); + } + + @Test + public void testBag() throws Exception { + BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_BAG_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)))); + + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), containsInAnyOrder("hello")); + + value.add("world"); + assertThat(value.read(), containsInAnyOrder("hello", "world")); + + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), equalTo(value)); + } + + @Test + public void testBagIsEmpty() throws Exception { + BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeBagIntoSource() throws Exception { + BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + + // Reading the merged bag gets both the contents + assertThat(bag1.read(), containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + + // Reading the merged bag gets both the contents + assertThat(bag3.read(), containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testSet() throws Exception { + + SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_SET_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_SET_ADDR)))); + + // empty + assertThat(value.read(), Matchers.emptyIterable()); + assertFalse(value.contains("A").read()); + + // add + value.add("A"); + value.add("B"); + value.add("A"); + assertFalse(value.addIfAbsent("B").read()); + assertThat(value.read(), containsInAnyOrder("A", "B")); + + // remove + value.remove("A"); + assertThat(value.read(), containsInAnyOrder("B")); + value.remove("C"); + assertThat(value.read(), containsInAnyOrder("B")); + + // contains + assertFalse(value.contains("A").read()); + assertTrue(value.contains("B").read()); + value.add("C"); + value.add("D"); + + // readLater + assertThat(value.readLater().read(), containsInAnyOrder("B", "C", "D")); + SetState later = value.readLater(); + assertThat(later.read(), hasItems("C", "D")); + assertFalse(later.contains("A").read()); + + // clear + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_SET_ADDR), equalTo(value)); + + } + + @Test + public void testSetIsEmpty() throws Exception { + + SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeSetIntoSource() throws Exception { + + SetState set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + SetState set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); + + set1.add("Hello"); + set2.add("Hello"); + set2.add("World"); + set1.add("!"); + + StateMerging.mergeSets(Arrays.asList(set1, set2), set1); + + // Reading the merged set gets both the contents + assertThat(set1.read(), containsInAnyOrder("Hello", "World", "!")); + assertThat(set2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeSetIntoNewNamespace() throws Exception { + + SetState set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + SetState set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); + SetState set3 = underTest.state(NAMESPACE_3, STRING_SET_ADDR); + + set1.add("Hello"); + set2.add("Hello"); + set2.add("World"); + set1.add("!"); + + StateMerging.mergeSets(Arrays.asList(set1, set2, set3), set3); + + // Reading the merged set gets both the contents + assertThat(set3.read(), containsInAnyOrder("Hello", "World", "!")); + assertThat(set1.read(), Matchers.emptyIterable()); + assertThat(set2.read(), Matchers.emptyIterable()); + } + + // for testMap + private static class MapEntry implements Map.Entry { + private K key; + private V value; + + private MapEntry(K key, V value) { + this.key = key; + this.value = value; + } + + static Map.Entry of(K k, V v) { + return new MapEntry<>(k, v); + } + + public final K getKey() { + return key; + } + public final V getValue() { + return value; + } + + public final String toString() { + return key + "=" + value; + } + + public final int hashCode() { + return Objects.hashCode(key) ^ Objects.hashCode(value); + } + + public final V setValue(V newValue) { + V oldValue = value; + value = newValue; + return oldValue; + } + + public final boolean equals(Object o) { + if (o == this) { + return true; + } + if (o instanceof Map.Entry) { + Map.Entry e = (Map.Entry) o; + if (Objects.equals(key, e.getKey()) + && Objects.equals(value, e.getValue())) { + return 
true; + } + } + return false; + } + } + + @Test + public void testMap() throws Exception { + + MapState value = underTest.state(NAMESPACE_1, STRING_MAP_ADDR); + + // State instances are cached, but depend on the namespace. + assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_MAP_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_MAP_ADDR)))); + + // put + assertThat(value.entries().read(), Matchers.emptyIterable()); + value.put("A", 1); + value.put("B", 2); + value.put("A", 11); + assertThat(value.putIfAbsent("B", 22).read(), equalTo(2)); + assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("A", 11), + MapEntry.of("B", 2))); + + // remove + value.remove("A"); + assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("B", 2))); + value.remove("C"); + assertThat(value.entries().read(), containsInAnyOrder(MapEntry.of("B", 2))); + + // get + assertNull(value.get("A").read()); + assertThat(value.get("B").read(), equalTo(2)); + value.put("C", 3); + value.put("D", 4); + assertThat(value.get("C").read(), equalTo(3)); + + // iterate + value.put("E", 5); + value.remove("C"); + assertThat(value.keys().read(), containsInAnyOrder("B", "D", "E")); + assertThat(value.values().read(), containsInAnyOrder(2, 4, 5)); + assertThat( + value.entries().read(), + containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); + + // readLater + assertThat(value.get("B").readLater().read(), equalTo(2)); + assertNull(value.get("A").readLater().read()); + assertThat( + value.entries().readLater().read(), + containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); + + // clear + value.clear(); + assertThat(value.entries().read(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), equalTo(value)); + } + + @Test + public void testCombiningValue() throws Exception { + + GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); + + assertThat(value.read(), equalTo(0)); + value.add(2); + assertThat(value.read(), equalTo(2)); + + value.add(3); + assertThat(value.read(), equalTo(5)); + + value.clear(); + assertThat(value.read(), equalTo(0)); + assertThat(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), equalTo(value)); + } + + @Test + public void testCombiningIsEmpty() throws Exception { + GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add(5); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeCombiningValueIntoSource() throws Exception { + CombiningState value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + assertThat(value1.read(), equalTo(11)); + assertThat(value2.read(), equalTo(10)); + + // Merging clears the old values and updates the result value. 
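    // After the merge, cells other than the target read as the combiner's identity
    // (0 for Sum.ofIntegers), which is what the assertions below rely on.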
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); + + assertThat(value1.read(), equalTo(21)); + assertThat(value2.read(), equalTo(0)); + } + + @Test + public void testMergeCombiningValueIntoNewNamespace() throws Exception { + CombiningState value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + CombiningState value3 = + underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); + + // Merging clears the old values and updates the result value. + assertThat(value1.read(), equalTo(0)); + assertThat(value2.read(), equalTo(0)); + assertThat(value3.read(), equalTo(21)); + } + + @Test + public void testWatermarkEarliestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), equalTo(new Instant(2000))); + + value.add(new Instant(1000)); + assertThat(value.read(), equalTo(new Instant(1000))); + + value.clear(); + assertThat(value.read(), equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), equalTo(value)); + } + + @Test + public void testWatermarkLatestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), equalTo(new Instant(3000))); + + value.add(new Instant(1000)); + assertThat(value.read(), equalTo(new Instant(3000))); + + value.clear(); + assertThat(value.read(), equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), equalTo(value)); + } + + @Test + public void testWatermarkEndOfWindowState() throws Exception { + WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), equalTo(new Instant(2000))); + + value.clear(); + assertThat(value.read(), equalTo(null)); + assertThat(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), equalTo(value)); + } + + @Test + public void testWatermarkStateIsEmpty() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState readFuture = value.isEmpty(); + value.add(new Instant(1000)); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeEarliestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the merged value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); + + assertThat(value1.read(), equalTo(new Instant(2000))); + assertThat(value2.read(), equalTo(null)); + } + + @Test + public void testMergeLatestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); + WatermarkHoldState value3 = + underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the result value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); + + // Merging clears the old values and updates the result value. 
+ assertThat(value3.read(), equalTo(new Instant(5000))); + assertThat(value1.read(), equalTo(null)); + assertThat(value2.read(), equalTo(null)); + } +} diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index c4c6b55d6cf46..a5b8203507ffb 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -381,5 +381,13 @@ test-jar test + + + org.apache.beam + beam-runners-core-java + test-jar + test + + diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index 35d2b786b82da..e7564ec914a2c 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -17,31 +17,11 @@ */ package org.apache.beam.runners.flink.streaming; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertThat; - import java.nio.ByteBuffer; -import java.util.Arrays; -import org.apache.beam.runners.core.StateMerging; -import org.apache.beam.runners.core.StateNamespace; -import org.apache.beam.runners.core.StateNamespaceForTest; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.GroupingState; -import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.state.WatermarkHoldState; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CoderUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; @@ -52,42 +32,17 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.hamcrest.Matchers; -import org.joda.time.Instant; -import org.junit.Before; -import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** - * Tests for {@link FlinkStateInternals}. This is based on the tests for - * {@code InMemoryStateInternals}. + * Tests for {@link FlinkStateInternals}. This is based on {@link StateInternalsTest}. 
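 *
 * <p>Set and map state are not yet supported by the Flink runner, so the corresponding inherited
 * cases ({@code testSet}, {@code testMap}, etc.) are overridden with empty bodies at the bottom of
 * this class; the inherited {@code @Test} is still discovered on the base class, but the call
 * dispatches to the no-op override, so those cases pass trivially instead of exercising
 * unsupported state.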
*/ @RunWith(JUnit4.class) -public class FlinkStateInternalsTest { - private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - - private static final StateTag> STRING_VALUE_ADDR = - StateTags.value("stringValue", StringUtf8Coder.of()); - private static final StateTag> - SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( - "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - private static final StateTag WATERMARK_EARLIEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); - private static final StateTag WATERMARK_LATEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); - private static final StateTag WATERMARK_EOW_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); - - FlinkStateInternals underTest; +public class FlinkStateInternalsTest extends StateInternalsTest { - @Before - public void initStateInternals() { + @Override + protected StateInternals createStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { AbstractKeyedStateBackend keyedStateBackend = backend.createKeyedStateBackend( @@ -98,296 +53,31 @@ public void initStateInternals() { 1, new KeyGroupRange(0, 0), new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); - underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of()); keyedStateBackend.setCurrentKey( ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello"))); + + return new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of()); } catch (Exception e) { throw new RuntimeException(e); } } - @Test - public void testValue() throws Exception { - ValueState value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); - - assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - assertNotEquals( - underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), - value); - - assertThat(value.read(), Matchers.nullValue()); - value.write("hello"); - assertThat(value.read(), Matchers.equalTo("hello")); - value.write("world"); - assertThat(value.read(), Matchers.equalTo("world")); - - value.clear(); - assertThat(value.read(), Matchers.nullValue()); - assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - - } - - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); - - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello")); - - value.add("world"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); - - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); - - } - - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = 
value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeBagIntoSource() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); - - // Reading the merged bag gets both the contents - assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag2.read(), Matchers.emptyIterable()); - } - - @Test - public void testMergeBagIntoNewNamespace() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); - - // Reading the merged bag gets both the contents - assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag1.read(), Matchers.emptyIterable()); - assertThat(bag2.read(), Matchers.emptyIterable()); - } - - @Test - public void testCombiningValue() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); - - assertThat(value.read(), Matchers.equalTo(0)); - value.add(2); - assertThat(value.read(), Matchers.equalTo(2)); - - value.add(3); - assertThat(value.read(), Matchers.equalTo(5)); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(0)); - assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); - } - - @Test - public void testCombiningIsEmpty() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(5); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeCombiningValueIntoSource() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - assertThat(value1.read(), Matchers.equalTo(11)); - assertThat(value2.read(), Matchers.equalTo(10)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); - - assertThat(value1.read(), Matchers.equalTo(21)); - assertThat(value2.read(), Matchers.equalTo(0)); - } - - @Test - public void testMergeCombiningValueIntoNewNamespace() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - CombiningState value3 = - underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); - - // Merging clears the old values and updates the result value. 
- assertThat(value1.read(), Matchers.equalTo(0)); - assertThat(value2.read(), Matchers.equalTo(0)); - assertThat(value3.read(), Matchers.equalTo(21)); - } - - @Test - public void testWatermarkEarliestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + ///////////////////////// Unsupported tests \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + @Override + public void testSet() {} - value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(1000))); + @Override + public void testSetIsEmpty() {} - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value); - } - - @Test - public void testWatermarkLatestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); + @Override + public void testMergeSetIntoSource() {} - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + @Override + public void testMergeSetIntoNewNamespace() {} - value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + @Override + public void testMap() {} - value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value); - } - - @Test - public void testWatermarkEndOfWindowState() throws Exception { - WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value); - } - - @Test - public void testWatermarkStateIsEmpty() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(new Instant(1000)); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeEarliestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the merged value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); - - assertThat(value1.read(), Matchers.equalTo(new Instant(2000))); - assertThat(value2.read(), Matchers.equalTo(null)); - } - - @Test - public void testMergeLatestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); - WatermarkHoldState value3 = - underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); - - // Merging clears the old values and updates the result value. 
- assertThat(value3.read(), Matchers.equalTo(new Instant(5000))); - assertThat(value1.read(), Matchers.equalTo(null)); - assertThat(value2.read(), Matchers.equalTo(null)); - } } From 809f17876d847002ba76979cb3362451fa01c110 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Mon, 12 Jun 2017 14:17:50 -0700 Subject: [PATCH 024/200] Reverse removal of NativeWrite evaluator in Python DirectRunner --- .../runners/direct/transform_evaluator.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 0fec8b8cb1db3..b1cb626ca0cb6 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -29,6 +29,7 @@ from apache_beam.runners.common import DoFnState from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.runners.direct.transform_result import TransformResult +from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue @@ -53,6 +54,7 @@ def __init__(self, evaluation_context): core.Flatten: _FlattenEvaluator, core.ParDo: _ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, + _NativeWrite: _NativeWriteEvaluator, } def for_application( @@ -96,7 +98,8 @@ def should_execute_serially(self, applied_ptransform): Returns: True if executor should execute applied_ptransform serially. """ - return isinstance(applied_ptransform.transform, core._GroupByKeyOnly) + return isinstance(applied_ptransform.transform, + (core._GroupByKeyOnly, _NativeWrite)) class _TransformEvaluator(object): @@ -400,3 +403,60 @@ def len_element_fn(element): return TransformResult( self._applied_ptransform, bundles, state, None, None, hold) + + +class _NativeWriteEvaluator(_TransformEvaluator): + """TransformEvaluator for _NativeWrite transform.""" + + def __init__(self, evaluation_context, applied_ptransform, + input_committed_bundle, side_inputs, scoped_metrics_container): + assert not side_inputs + super(_NativeWriteEvaluator, self).__init__( + evaluation_context, applied_ptransform, input_committed_bundle, + side_inputs, scoped_metrics_container) + + assert applied_ptransform.transform.sink + self._sink = applied_ptransform.transform.sink + + @property + def _is_final_bundle(self): + return (self._execution_context.watermarks.input_watermark + == WatermarkManager.WATERMARK_POS_INF) + + @property + def _has_already_produced_output(self): + return (self._execution_context.watermarks.output_watermark + == WatermarkManager.WATERMARK_POS_INF) + + def start_bundle(self): + # state: [values] + self.state = (self._execution_context.existing_state + if self._execution_context.existing_state else []) + + def process_element(self, element): + self.state.append(element) + + def finish_bundle(self): + # finish_bundle will append incoming bundles in memory until all the bundles + # carrying data is processed. This is done to produce only a single output + # shard (some tests depends on this behavior). It is possible to have + # incoming empty bundles after the output is produced, these bundles will be + # ignored and would not generate additional output files. + # TODO(altay): Do not wait until the last bundle to write in a single shard. 
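    # Watermark-hold mechanics, as implemented below: until the input watermark reaches +inf,
    # the accumulated elements are returned as state together with a hold at -inf, which keeps
    # the output watermark back so the evaluator is re-run with the carried state; on the final
    # bundle the elements are written through the sink's writer and the hold is released to +inf.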
+ if self._is_final_bundle: + if self._has_already_produced_output: + # Ignore empty bundles that arrive after the output is produced. + assert self.state == [] + else: + self._sink.pipeline_options = self._evaluation_context.pipeline_options + with self._sink.writer() as writer: + for v in self.state: + writer.Write(v.value) + state = None + hold = WatermarkManager.WATERMARK_POS_INF + else: + state = self.state + hold = WatermarkManager.WATERMARK_NEG_INF + + return TransformResult( + self._applied_ptransform, [], state, None, None, hold) From f0f98c70b0b01c43ce1b093c2035b20bd90ba907 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 12 Jun 2017 11:07:47 -0700 Subject: [PATCH 025/200] Cleanup Combine Tests with Context Split out the "shared" bit of all the accumulators, so they show up as an explicit component of the final result string. Update timestamped creation logic. --- .../beam/sdk/transforms/CombineTest.java | 225 ++++++++++++------ 1 file changed, 154 insertions(+), 71 deletions(-) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java index dc9788f69a1e1..6a4348de5567e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java @@ -95,14 +95,6 @@ public class CombineTest implements Serializable { // This test is Serializable, just so that it's easy to have // anonymous inner classes inside the non-static test methods. - static final List> TABLE = Arrays.asList( - KV.of("a", 1), - KV.of("a", 1), - KV.of("a", 4), - KV.of("b", 1), - KV.of("b", 13) - ); - static final List> EMPTY_TABLE = Collections.emptyList(); @Mock private DoFn.ProcessContext processContext; @@ -168,16 +160,28 @@ private void runTestSimpleCombineWithContext(List> table, @Category(ValidatesRunner.class) @SuppressWarnings({"rawtypes", "unchecked"}) public void testSimpleCombine() { - runTestSimpleCombine(TABLE, 20, Arrays.asList(KV.of("a", "114"), KV.of("b", "113"))); + runTestSimpleCombine(Arrays.asList( + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + ), 20, Arrays.asList(KV.of("a", "114"), KV.of("b", "113"))); } @Test @Category(ValidatesRunner.class) @SuppressWarnings({"rawtypes", "unchecked"}) public void testSimpleCombineWithContext() { - runTestSimpleCombineWithContext(TABLE, 20, - Arrays.asList(KV.of("a", "01124"), KV.of("b", "01123")), - new String[] {"01111234"}); + runTestSimpleCombineWithContext(Arrays.asList( + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + ), 20, + Arrays.asList(KV.of("a", "20:114"), KV.of("b", "20:113")), + new String[] {"20:111134"}); } @Test @@ -216,7 +220,13 @@ private void runTestBasicCombine(List> table, @Test @Category(ValidatesRunner.class) public void testBasicCombine() { - runTestBasicCombine(TABLE, ImmutableSet.of(1, 13, 4), Arrays.asList( + runTestBasicCombine(Arrays.asList( + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + ), ImmutableSet.of(1, 13, 4), Arrays.asList( KV.of("a", (Set) ImmutableSet.of(1, 4)), KV.of("b", (Set) ImmutableSet.of(1, 13)))); } @@ -251,9 +261,16 @@ private void runTestAccumulatingCombine(List> table, @Category(ValidatesRunner.class) public void testFixedWindowsCombine() { PCollection> input = - pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 1L, 6L, 7L, 8L)) - 
.withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) - .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(0L)), + TimestampedValue.of(KV.of("a", 1), new Instant(1L)), + TimestampedValue.of(KV.of("a", 4), new Instant(6L)), + TimestampedValue.of(KV.of("b", 1), new Instant(7L)), + TimestampedValue.of(KV.of("b", 13), new Instant(8L))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); PCollection sum = input .apply(Values.create()) @@ -275,9 +292,16 @@ public void testFixedWindowsCombine() { @Category(ValidatesRunner.class) public void testFixedWindowsCombineWithContext() { PCollection> perKeyInput = - pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 1L, 6L, 7L, 8L)) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) - .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(0L)), + TimestampedValue.of(KV.of("a", 1), new Instant(1L)), + TimestampedValue.of(KV.of("a", 4), new Instant(6L)), + TimestampedValue.of(KV.of("b", 1), new Instant(7L)), + TimestampedValue.of(KV.of("b", 13), new Instant(8L))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(FixedWindows.of(Duration.millis(2)))); PCollection globallyInput = perKeyInput.apply(Values.create()); @@ -298,26 +322,33 @@ public void testFixedWindowsCombineWithContext() { PAssert.that(sum).containsInAnyOrder(2, 5, 13); PAssert.that(combinePerKeyWithContext).containsInAnyOrder( - KV.of("a", "112"), - KV.of("a", "45"), - KV.of("b", "15"), - KV.of("b", "1133")); - PAssert.that(combineGloballyWithContext).containsInAnyOrder("112", "145", "1133"); + KV.of("a", "2:11"), + KV.of("a", "5:4"), + KV.of("b", "5:1"), + KV.of("b", "13:13")); + PAssert.that(combineGloballyWithContext).containsInAnyOrder("2:11", "5:14", "13:13"); pipeline.run(); } @Test @Category(ValidatesRunner.class) public void testSlidingWindowsCombineWithContext() { + // [a: 1, 1], [a: 4; b: 1], [b: 13] PCollection> perKeyInput = - pipeline.apply(Create.timestamped(TABLE, Arrays.asList(2L, 3L, 8L, 9L, 10L)) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) - .apply(Window.>into(SlidingWindows.of(Duration.millis(2)))); + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(2L)), + TimestampedValue.of(KV.of("a", 1), new Instant(3L)), + TimestampedValue.of(KV.of("a", 4), new Instant(8L)), + TimestampedValue.of(KV.of("b", 1), new Instant(9L)), + TimestampedValue.of(KV.of("b", 13), new Instant(10L))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(SlidingWindows.of(Duration.millis(2)))); PCollection globallyInput = perKeyInput.apply(Values.create()); - PCollection sum = globallyInput - .apply("Sum", Combine.globally(new SumInts()).withoutDefaults()); + PCollection sum = globallyInput.apply("Sum", Sum.integersGlobally().withoutDefaults()); PCollectionView globallySumView = sum.apply(View.asSingleton()); @@ -333,16 +364,16 @@ public void testSlidingWindowsCombineWithContext() { PAssert.that(sum).containsInAnyOrder(1, 2, 1, 4, 5, 14, 13); PAssert.that(combinePerKeyWithContext).containsInAnyOrder( - KV.of("a", "11"), - KV.of("a", "112"), - KV.of("a", "11"), - KV.of("a", "44"), - KV.of("a", "45"), - KV.of("b", 
"15"), - KV.of("b", "11134"), - KV.of("b", "1133")); + KV.of("a", "1:1"), + KV.of("a", "2:11"), + KV.of("a", "1:1"), + KV.of("a", "4:4"), + KV.of("a", "5:4"), + KV.of("b", "5:1"), + KV.of("b", "14:113"), + KV.of("b", "13:13")); PAssert.that(combineGloballyWithContext).containsInAnyOrder( - "11", "112", "11", "44", "145", "11134", "1133"); + "1:1", "2:11", "1:1", "4:4", "5:14", "14:113", "13:13"); pipeline.run(); } @@ -383,9 +414,16 @@ public Void apply(Iterable input) { @Category(ValidatesRunner.class) public void testSessionsCombine() { PCollection> input = - pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 4L, 7L, 10L, 16L)) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) - .apply(Window.>into(Sessions.withGapDuration(Duration.millis(5)))); + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(0L)), + TimestampedValue.of(KV.of("a", 1), new Instant(4L)), + TimestampedValue.of(KV.of("a", 4), new Instant(7L)), + TimestampedValue.of(KV.of("b", 1), new Instant(10L)), + TimestampedValue.of(KV.of("b", 13), new Instant(16L))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))) + .apply(Window.>into(Sessions.withGapDuration(Duration.millis(5)))); PCollection sum = input .apply(Values.create()) @@ -406,7 +444,13 @@ public void testSessionsCombine() { @Category(ValidatesRunner.class) public void testSessionsCombineWithContext() { PCollection> perKeyInput = - pipeline.apply(Create.timestamped(TABLE, Arrays.asList(0L, 4L, 7L, 10L, 16L)) + pipeline.apply( + Create.timestamped( + TimestampedValue.of(KV.of("a", 1), new Instant(0L)), + TimestampedValue.of(KV.of("a", 1), new Instant(4L)), + TimestampedValue.of(KV.of("a", 4), new Instant(7L)), + TimestampedValue.of(KV.of("b", 1), new Instant(10L)), + TimestampedValue.of(KV.of("b", 13), new Instant(16L))) .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); PCollection globallyInput = perKeyInput.apply(Values.create()); @@ -429,19 +473,22 @@ public void testSessionsCombineWithContext() { new TestCombineFnWithContext(globallyFixedWindowsView)) .withSideInputs(Arrays.asList(globallyFixedWindowsView))); - PCollection sessionsCombineGlobally = globallyInput - .apply("Globally Input Sessions", - Window.into(Sessions.withGapDuration(Duration.millis(5)))) - .apply(Combine.globally(new TestCombineFnWithContext(globallyFixedWindowsView)) - .withoutDefaults() - .withSideInputs(Arrays.asList(globallyFixedWindowsView))); + PCollection sessionsCombineGlobally = + globallyInput + .apply( + "Globally Input Sessions", + Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply( + Combine.globally(new TestCombineFnWithContext(globallyFixedWindowsView)) + .withoutDefaults() + .withSideInputs(Arrays.asList(globallyFixedWindowsView))); PAssert.that(fixedWindowsSum).containsInAnyOrder(2, 4, 1, 13); PAssert.that(sessionsCombinePerKey).containsInAnyOrder( - KV.of("a", "1114"), - KV.of("b", "11"), - KV.of("b", "013")); - PAssert.that(sessionsCombineGlobally).containsInAnyOrder("11114", "013"); + KV.of("a", "1:114"), + KV.of("b", "1:1"), + KV.of("b", "0:13")); + PAssert.that(sessionsCombineGlobally).containsInAnyOrder("1:1114", "0:13"); pipeline.run(); } @@ -461,7 +508,13 @@ public void testWindowedCombineEmpty() { @Test @Category(ValidatesRunner.class) public void testAccumulatingCombine() { - runTestAccumulatingCombine(TABLE, 4.0, Arrays.asList(KV.of("a", 2.0), KV.of("b", 7.0))); + runTestAccumulatingCombine(Arrays.asList( + KV.of("a", 1), + 
KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + ), 4.0, Arrays.asList(KV.of("a", 2.0), KV.of("b", 7.0))); } @Test @@ -503,7 +556,13 @@ public Integer apply(String input) { @Test @Category(ValidatesRunner.class) public void testHotKeyCombining() { - PCollection> input = copy(createInput(pipeline, TABLE), 10); + PCollection> input = copy(createInput(pipeline, Arrays.asList( + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + )), 10); CombineFn mean = new MeanInts(); PCollection> coldMean = input.apply("ColdMean", @@ -560,7 +619,13 @@ public Void apply(Iterable input) { @Test @Category(NeedsRunner.class) public void testBinaryCombineFn() { - PCollection> input = copy(createInput(pipeline, TABLE), 2); + PCollection> input = copy(createInput(pipeline, Arrays.asList( + KV.of("a", 1), + KV.of("a", 1), + KV.of("a", 4), + KV.of("b", 1), + KV.of("b", 13) + )), 2); PCollection> intProduct = input .apply("IntProduct", Combine.perKey(new TestProdInt())); PCollection> objProduct = input @@ -917,8 +982,10 @@ public static class TestCombineFn extends CombineFn getAccumulatorCoder( @Override public Accumulator createAccumulator() { - return new Accumulator(""); + return new Accumulator("", ""); } @Override public Accumulator addInput(Accumulator accumulator, Integer value) { try { - return new Accumulator(accumulator.value + String.valueOf(value)); + return new Accumulator(accumulator.seed, accumulator.value + String.valueOf(value)); } finally { accumulator.value = "cleared in addInput"; } @@ -972,12 +1042,18 @@ public Accumulator addInput(Accumulator accumulator, Integer value) { @Override public Accumulator mergeAccumulators(Iterable accumulators) { + String seed = null; String all = ""; for (Accumulator accumulator : accumulators) { + if (seed == null) { + seed = accumulator.seed; + } else { + checkArgument(seed.equals(accumulator.seed), "Different seed values in accumulator"); + } all += accumulator.value; accumulator.value = "cleared in mergeAccumulators"; } - return new Accumulator(all); + return new Accumulator(seed, all); } @Override @@ -1007,40 +1083,47 @@ public Coder getAccumulatorCoder( @Override public TestCombineFn.Accumulator createAccumulator(Context c) { - return new TestCombineFn.Accumulator(c.sideInput(view).toString()); + Integer sideInputValue = c.sideInput(view); + return new TestCombineFn.Accumulator(sideInputValue.toString(), ""); } @Override public TestCombineFn.Accumulator addInput( TestCombineFn.Accumulator accumulator, Integer value, Context c) { try { - assertThat(accumulator.value, Matchers.startsWith(c.sideInput(view).toString())); - return new TestCombineFn.Accumulator(accumulator.value + String.valueOf(value)); + assertThat( + "Not expecting view contents to change", + accumulator.seed, + Matchers.equalTo(Integer.toString(c.sideInput(view)))); + return new TestCombineFn.Accumulator( + accumulator.seed, accumulator.value + String.valueOf(value)); } finally { accumulator.value = "cleared in addInput"; } - } @Override public TestCombineFn.Accumulator mergeAccumulators( Iterable accumulators, Context c) { - String prefix = c.sideInput(view).toString(); - String all = prefix; + String sideInputValue = c.sideInput(view).toString(); + StringBuilder all = new StringBuilder(); for (TestCombineFn.Accumulator accumulator : accumulators) { - assertThat(accumulator.value, Matchers.startsWith(prefix)); - all += accumulator.value.substring(prefix.length()); + assertThat( + "Accumulators should all have the same Side Input Value", + 
accumulator.seed, + Matchers.equalTo(sideInputValue)); + all.append(accumulator.value); accumulator.value = "cleared in mergeAccumulators"; } - return new TestCombineFn.Accumulator(all); + return new TestCombineFn.Accumulator(sideInputValue, all.toString()); } @Override public String extractOutput(TestCombineFn.Accumulator accumulator, Context c) { - assertThat(accumulator.value, Matchers.startsWith(c.sideInput(view).toString())); + assertThat(accumulator.seed, Matchers.startsWith(c.sideInput(view).toString())); char[] chars = accumulator.value.toCharArray(); Arrays.sort(chars); - return new String(chars); + return accumulator.seed + ":" + new String(chars); } } From 1ac18b2eb1371422e60d50a8c3f37b3b24d59611 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 12 Jun 2017 16:55:59 -0700 Subject: [PATCH 026/200] Check for Deferral on Non-additional inputs Because Side Inputs are represented within the expanded inputs, the check that the transform is a Combine with Side Inputs would never be hit. This ensures that we do not consider additional inputs during the check to defer evaluation of the node. --- .../java/org/apache/beam/runners/spark/SparkRunner.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index 9e2426ef83814..d008718af0cc8 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; import org.apache.beam.runners.spark.io.CreateStream; import org.apache.beam.runners.spark.metrics.AggregatorMetricSource; @@ -359,10 +360,12 @@ public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) { protected boolean shouldDefer(TransformHierarchy.Node node) { // if the input is not a PCollection, or it is but with non merging windows, don't defer. 
- if (node.getInputs().size() != 1) { + Collection nonAdditionalInputs = + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); + if (nonAdditionalInputs.size() != 1) { return false; } - PValue input = Iterables.getOnlyElement(node.getInputs().values()); + PValue input = Iterables.getOnlyElement(nonAdditionalInputs); if (!(input instanceof PCollection) || ((PCollection) input).getWindowingStrategy().getWindowFn().isNonMerging()) { return false; From 6859f80400bb16bbca34ed282e7a5e8a1328f955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Mon, 5 Jun 2017 23:20:27 +0200 Subject: [PATCH 027/200] [BEAM-2412] Update HBaseIO to use HBase client 1.2.6 --- sdks/java/io/hbase/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml index 746b993d6b04f..f81cd2461dcef 100644 --- a/sdks/java/io/hbase/pom.xml +++ b/sdks/java/io/hbase/pom.xml @@ -31,7 +31,7 @@ Library to read and write from/to HBase - 1.2.5 + 1.2.6 2.5.1 From 4c36508733a69fafce0f7dfb86c71eee5eb6bc84 Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Wed, 7 Jun 2017 14:34:25 +0800 Subject: [PATCH 028/200] Use CoderTypeSerializer and remove unuse code in FlinkStateInternals --- .../streaming/state/FlinkStateInternals.java | 198 +----------------- 1 file changed, 10 insertions(+), 188 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index f0d3278191919..d8771de998f25 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -196,9 +195,8 @@ private static class FlinkValueState implements ValueState { this.address = address; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(coder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(coder)); } @Override @@ -282,9 +280,8 @@ private static class FlinkBagState implements BagState { this.address = address; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(coder); - - flinkStateDescriptor = new ListStateDescriptor<>(address.getId(), typeInfo); + flinkStateDescriptor = new ListStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(coder)); } @Override @@ -398,9 +395,8 @@ private static class FlinkCombiningState this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), 
new CoderTypeSerializer<>(accumCoder)); } @Override @@ -545,179 +541,6 @@ public int hashCode() { } } - private static class FlinkKeyedCombiningState - implements CombiningState { - - private final StateNamespace namespace; - private final StateTag> address; - private final Combine.CombineFn combineFn; - private final ValueStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; - private final FlinkStateInternals flinkStateInternals; - - FlinkKeyedCombiningState( - KeyedStateBackend flinkStateBackend, - StateTag> address, - Combine.CombineFn combineFn, - StateNamespace namespace, - Coder accumCoder, - FlinkStateInternals flinkStateInternals) { - - this.namespace = namespace; - this.address = address; - this.combineFn = combineFn; - this.flinkStateBackend = flinkStateBackend; - this.flinkStateInternals = flinkStateInternals; - - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); - } - - @Override - public CombiningState readLater() { - return this; - } - - @Override - public void add(InputT value) { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT current = state.value(); - if (current == null) { - current = combineFn.createAccumulator(); - } - current = combineFn.addInput(current, value); - state.update(current); - } catch (RuntimeException re) { - throw re; - } catch (Exception e) { - throw new RuntimeException("Error adding to state." , e); - } - } - - @Override - public void addAccum(AccumT accum) { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT current = state.value(); - if (current == null) { - state.update(accum); - } else { - current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); - state.update(current); - } - } catch (Exception e) { - throw new RuntimeException("Error adding to state.", e); - } - } - - @Override - public AccumT getAccum() { - try { - return flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).value(); - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - } - - @Override - public AccumT mergeAccumulators(Iterable accumulators) { - return combineFn.mergeAccumulators(accumulators); - } - - @Override - public OutputT read() { - try { - org.apache.flink.api.common.state.ValueState state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - - AccumT accum = state.value(); - if (accum != null) { - return combineFn.extractOutput(accum); - } else { - return combineFn.extractOutput(combineFn.createAccumulator()); - } - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - } - - @Override - public ReadableState isEmpty() { - return new ReadableState() { - @Override - public Boolean read() { - try { - return flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).value() == null; - } catch (Exception e) { - throw new RuntimeException("Error reading state.", e); - } - - } - - @Override - public ReadableState readLater() { - return this; - } - }; - } - - @Override - public 
void clear() { - try { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).clear(); - } catch (Exception e) { - throw new RuntimeException("Error clearing state.", e); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - FlinkKeyedCombiningState that = - (FlinkKeyedCombiningState) o; - - return namespace.equals(that.namespace) && address.equals(that.address); - - } - - @Override - public int hashCode() { - int result = namespace.hashCode(); - result = 31 * result + address.hashCode(); - return result; - } - } - private static class FlinkCombiningStateWithContext implements CombiningState { @@ -745,9 +568,8 @@ private static class FlinkCombiningStateWithContext this.flinkStateInternals = flinkStateInternals; this.context = context; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(accumCoder); - - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(accumCoder)); } @Override @@ -913,8 +735,8 @@ public FlinkWatermarkHoldState( this.flinkStateBackend = flinkStateBackend; this.flinkStateInternals = flinkStateInternals; - CoderTypeInformation typeInfo = new CoderTypeInformation<>(InstantCoder.of()); - flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + flinkStateDescriptor = new ValueStateDescriptor<>( + address.getId(), new CoderTypeSerializer<>(InstantCoder.of())); } @Override From 10b166b355a03daeae78dd1e71016fc72805939d Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Wed, 7 Jun 2017 14:40:30 +0800 Subject: [PATCH 029/200] [BEAM-1483] Support SetState in Flink runner and fix MapState to be consistent with InMemoryStateInternals. 
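For context, this change makes the user-facing SetState API usable on the Flink runner by backing it with a Flink MapStateDescriptor whose values are Booleans (see FlinkSetState in the diff below). As a rough, illustrative sketch only -- the DoFn name, state id, and coder here are assumptions for illustration and are not code from this patch -- a pipeline author could now deduplicate values per key with a stateful DoFn such as:

    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.state.SetState;
    import org.apache.beam.sdk.state.StateSpec;
    import org.apache.beam.sdk.state.StateSpecs;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;

    // Illustrative sketch: emits each value only the first time it is seen for a key.
    class DedupPerKeyFn extends DoFn<KV<String, String>, String> {

      @StateId("seen")
      private final StateSpec<SetState<String>> seenSpec = StateSpecs.set(StringUtf8Coder.of());

      @ProcessElement
      public void process(ProcessContext c, @StateId("seen") SetState<String> seen) {
        String value = c.element().getValue();
        // addIfAbsent reports whether the value was newly added; only first occurrences are output.
        if (seen.addIfAbsent(value).read()) {
          c.output(value);
        }
      }
    }

The addIfAbsent contract relied on above matches what FlinkSetState returns below via ReadableStates.immediate(!alreadyContained).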
--- runners/flink/pom.xml | 1 - .../streaming/state/FlinkStateInternals.java | 227 ++++++++++++++---- .../streaming/FlinkStateInternalsTest.java | 17 -- 3 files changed, 182 insertions(+), 63 deletions(-) diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index a5b8203507ffb..339aa8e445a97 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -91,7 +91,6 @@ org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders, org.apache.beam.sdk.testing.LargeKeys$Above100MB, - org.apache.beam.sdk.testing.UsesSetState, org.apache.beam.sdk.testing.UsesCommittedMetrics, org.apache.beam.sdk.testing.UsesTestStream, org.apache.beam.sdk.testing.UsesSplittableParDo diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index d8771de998f25..a0b015b57d329 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import java.nio.ByteBuffer; import java.util.Collections; @@ -33,6 +34,7 @@ import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; @@ -48,6 +50,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.joda.time.Instant; @@ -127,8 +130,8 @@ public BagState bindBag( @Override public SetState bindSet( StateTag> address, Coder elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); + return new FlinkSetState<>( + flinkStateBackend, address, namespace, elemCoder); } @Override @@ -875,24 +878,15 @@ private static class FlinkMapState implements MapState get(final KeyT input) { - return new ReadableState() { - @Override - public ValueT read() { - try { - return flinkStateBackend.getPartitionedState( + try { + return ReadableStates.immediate( + flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, - flinkStateDescriptor).get(input); - } catch (Exception e) { - throw new RuntimeException("Error get from state.", e); - } - } - - @Override - public ReadableState readLater() { - return this; - } - }; + flinkStateDescriptor).get(input)); + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } } @Override @@ -909,32 +903,22 @@ public void put(KeyT key, ValueT value) { @Override public ReadableState putIfAbsent(final KeyT key, final ValueT value) { - return new ReadableState() { - @Override - public ValueT read() { - try { - ValueT current = flinkStateBackend.getPartitionedState( - 
namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).get(key); - - if (current == null) { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).put(key, value); - } - return current; - } catch (Exception e) { - throw new RuntimeException("Error put kv to state.", e); - } - } + try { + ValueT current = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(key); - @Override - public ReadableState readLater() { - return this; + if (current == null) { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(key, value); } - }; + return ReadableStates.immediate(current); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } } @Override @@ -955,10 +939,11 @@ public ReadableState> keys() { @Override public Iterable read() { try { - return flinkStateBackend.getPartitionedState( + Iterable result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).keys(); + return result != null ? result : Collections.emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state keys.", e); } @@ -977,10 +962,11 @@ public ReadableState> values() { @Override public Iterable read() { try { - return flinkStateBackend.getPartitionedState( + Iterable result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).values(); + return result != null ? result : Collections.emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state values.", e); } @@ -999,10 +985,11 @@ public ReadableState>> entries() { @Override public Iterable> read() { try { - return flinkStateBackend.getPartitionedState( + Iterable> result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).entries(); + return result != null ? result : Collections.>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state entries.", e); } @@ -1050,4 +1037,154 @@ public int hashCode() { } } + private static class FlinkSetState implements SetState { + + private final StateNamespace namespace; + private final StateTag> address; + private final MapStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + + FlinkSetState( + KeyedStateBackend flinkStateBackend, + StateTag> address, + StateNamespace namespace, + Coder coder) { + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(), + new CoderTypeSerializer<>(coder), new BooleanSerializer()); + } + + @Override + public ReadableState contains(final T t) { + try { + Boolean result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(t); + return ReadableStates.immediate(result != null ? 
result : false); + } catch (Exception e) { + throw new RuntimeException("Error contains value from state.", e); + } + } + + @Override + public ReadableState addIfAbsent(final T t) { + try { + org.apache.flink.api.common.state.MapState state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + boolean alreadyContained = state.contains(t); + if (!alreadyContained) { + state.put(t, true); + } + return ReadableStates.immediate(!alreadyContained); + } catch (Exception e) { + throw new RuntimeException("Error addIfAbsent value to state.", e); + } + } + + @Override + public void remove(T t) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).remove(t); + } catch (Exception e) { + throw new RuntimeException("Error remove value to state.", e); + } + } + + @Override + public SetState readLater() { + return this; + } + + @Override + public void add(T value) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(value, true); + } catch (Exception e) { + throw new RuntimeException("Error add value to state.", e); + } + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + try { + Iterable result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result == null || Iterables.isEmpty(result); + } catch (Exception e) { + throw new RuntimeException("Error isEmpty from state.", e); + } + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public Iterable read() { + try { + Iterable result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result != null ? 
result : Collections.emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error read from state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkSetState that = (FlinkSetState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index e7564ec914a2c..b8d41de77b448 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -63,21 +63,4 @@ protected StateInternals createStateInternals() { } } - ///////////////////////// Unsupported tests \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - @Override - public void testSet() {} - - @Override - public void testSetIsEmpty() {} - - @Override - public void testMergeSetIntoSource() {} - - @Override - public void testMergeSetIntoNewNamespace() {} - - @Override - public void testMap() {} - } From 4d18606378f43c7b0d3ac05d45ca6e0570e49eef Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Tue, 13 Jun 2017 10:15:33 +0800 Subject: [PATCH 030/200] Add set and map readable test to StateInternalsTest --- .../beam/runners/core/StateInternalsTest.java | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index bf3156aad110e..6011fb48aed67 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; import java.util.Arrays; import java.util.Map; import java.util.Objects; @@ -570,4 +571,43 @@ public void testMergeLatestWatermarkIntoSource() throws Exception { assertThat(value1.read(), equalTo(null)); assertThat(value2.read(), equalTo(null)); } + + @Test + public void testSetReadable() throws Exception { + SetState value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + // test contains + ReadableState readable = value.contains("A"); + value.add("A"); + assertFalse(readable.read()); + + // test addIfAbsent + value.addIfAbsent("B"); + assertTrue(value.contains("B").read()); + } + + @Test + public void testMapReadable() throws Exception { + MapState value = underTest.state(NAMESPACE_1, STRING_MAP_ADDR); + + // test iterable, should just return a iterable view of the values contained in this map. + // The iterable is backed by the map, so changes to the map are reflected in the iterable. 
+ ReadableState> keys = value.keys(); + ReadableState> values = value.values(); + ReadableState>> entries = value.entries(); + value.put("A", 1); + assertFalse(Iterables.isEmpty(keys.read())); + assertFalse(Iterables.isEmpty(values.read())); + assertFalse(Iterables.isEmpty(entries.read())); + + // test get + ReadableState get = value.get("B"); + value.put("B", 2); + assertNull(get.read()); + + // test addIfAbsent + value.putIfAbsent("C", 3); + assertThat(value.get("C").read(), equalTo(3)); + } + } From e71eb66ae319bdf0cdad1fe9b54662962c8e8f16 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 9 Jun 2017 16:44:55 -0700 Subject: [PATCH 031/200] Actually test the fn_api_runner. The test suite was not being run due to a typo. Fix breakage due to changes in the code in the meantime. --- .../runners/portability/fn_api_runner.py | 16 ++++++---------- .../runners/portability/fn_api_runner_test.py | 4 ++-- .../apache_beam/runners/worker/operations.py | 1 + .../apache_beam/runners/worker/sdk_worker.py | 2 +- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index a83eae403701e..8c213ad08f381 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -179,7 +179,7 @@ def outputs(op): # into the sdk worker, or an injection of the source object into the # sdk worker as data followed by an SDF that reads that source. if (isinstance(operation.source.source, - worker_runner_base.InMemorySource) + maptask_executor_runner.InMemorySource) and isinstance(operation.source.source.default_output_coder(), WindowedValueCoder)): output_stream = create_OutputStream() @@ -264,11 +264,9 @@ def outputs(op): element_coder.get_impl().encode_to_stream( element, output_stream, True) elements_data = output_stream.get() - state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id) + state_key = beam_fn_api_pb2.StateKey(key=view_id) state_handler.Clear(state_key) - state_handler.Append( - beam_fn_api_pb2.SimpleStateAppendRequest( - state_key=state_key, data=[elements_data])) + state_handler.Append(state_key, elements_data) elif isinstance(operation, operation_specs.WorkerFlatten): fn = sdk_worker.pack_function_spec_data( @@ -382,9 +380,8 @@ def Get(self, state_key): return beam_fn_api_pb2.Elements.Data( data=''.join(self._all[self._to_key(state_key)])) - def Append(self, append_request): - self._all[self._to_key(append_request.state_key)].extend( - append_request.data) + def Append(self, state_key, data): + self._all[self._to_key(state_key)].extend(data) def Clear(self, state_key): try: @@ -394,8 +391,7 @@ def Clear(self, state_key): @staticmethod def _to_key(state_key): - return (state_key.function_spec_reference, state_key.window, - state_key.key) + return state_key.window, state_key.key class DirectController(object): """An in-memory controller for fn API control, state and data planes.""" diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 633602f2a9ac3..66d985a9b053c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -20,10 +20,10 @@ import apache_beam as beam from apache_beam.runners.portability import fn_api_runner -from apache_beam.runners.portability import maptask_executor_runner 
+from apache_beam.runners.portability import maptask_executor_runner_test -class FnApiRunnerTest(maptask_executor_runner.MapTaskExecutorRunner): +class FnApiRunnerTest(maptask_executor_runner_test.MapTaskExecutorRunnerTest): def create_pipeline(self): return beam.Pipeline(runner=fn_api_runner.FnApiRunner()) diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index a44561d096c83..c4f945bf2b9d6 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -129,6 +129,7 @@ def __init__(self, operation_name, spec, counter_factory, state_sampler): self.operation_name + '-finish') # TODO(ccy): the '-abort' state can be added when the abort is supported in # Operations. + self.scoped_metrics_container = None def start(self): """Start operation.""" diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 33f2b61dd3883..dc4f5c2ffab0b 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -359,7 +359,7 @@ def create_side_input(tag, si): source=SideInputSource( self.state_handler, beam_fn_api_pb2.StateKey( - function_spec_reference=si.view_fn.id), + key=si.view_fn.id.encode('utf-8')), coder=unpack_and_deserialize_py_fn(si.view_fn))) output_tags = list(transform.outputs.keys()) spec = operation_specs.WorkerDoFn( From b547b5a1eee61e293ff2f8ccfac57f308867328c Mon Sep 17 00:00:00 2001 From: Maria Garcia Herrero Date: Fri, 9 Jun 2017 23:34:59 -0700 Subject: [PATCH 032/200] Make unique test names for value-provider arguments --- .../options/pipeline_options_test.py | 39 ++++---- .../options/value_provider_test.py | 93 ++++++++++--------- 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index 1a644b449e423..f4dd4d92b788b 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -192,47 +192,52 @@ def _add_argparse_args(cls, parser): options = PipelineOptions(['--redefined_flag']) self.assertTrue(options.get_all_options()['redefined_flag']) + # TODO(BEAM-1319): Require unique names only within a test. + # For now, _vp_arg will be the convention + # to name value-provider arguments in tests, as opposed to + # _non_vp_arg for non-value-provider arguments. + # The number will grow per file as tests are added. 
def test_value_provider_options(self): class UserOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - '--vp_arg', + '--pot_vp_arg1', help='This flag is a value provider') parser.add_value_provider_argument( - '--vp_arg2', + '--pot_vp_arg2', default=1, type=int) parser.add_argument( - '--non_vp_arg', + '--pot_non_vp_arg1', default=1, type=int ) # Provide values: if not provided, the option becomes of the type runtime vp - options = UserOptions(['--vp_arg', 'hello']) - self.assertIsInstance(options.vp_arg, StaticValueProvider) - self.assertIsInstance(options.vp_arg2, RuntimeValueProvider) - self.assertIsInstance(options.non_vp_arg, int) + options = UserOptions(['--pot_vp_arg1', 'hello']) + self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider) + self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider) + self.assertIsInstance(options.pot_non_vp_arg1, int) # Values can be overwritten - options = UserOptions(vp_arg=5, - vp_arg2=StaticValueProvider(value_type=str, - value='bye'), - non_vp_arg=RuntimeValueProvider( + options = UserOptions(pot_vp_arg1=5, + pot_vp_arg2=StaticValueProvider(value_type=str, + value='bye'), + pot_non_vp_arg1=RuntimeValueProvider( option_name='foo', value_type=int, default_value=10)) - self.assertEqual(options.vp_arg, 5) - self.assertTrue(options.vp_arg2.is_accessible(), - '%s is not accessible' % options.vp_arg2) - self.assertEqual(options.vp_arg2.get(), 'bye') - self.assertFalse(options.non_vp_arg.is_accessible()) + self.assertEqual(options.pot_vp_arg1, 5) + self.assertTrue(options.pot_vp_arg2.is_accessible(), + '%s is not accessible' % options.pot_vp_arg2) + self.assertEqual(options.pot_vp_arg2.get(), 'bye') + self.assertFalse(options.pot_non_vp_arg1.is_accessible()) with self.assertRaises(RuntimeError): - options.non_vp_arg.get() + options.pot_non_vp_arg1.get() if __name__ == '__main__': diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py index 3a45e8b7d122f..17e9590d2a363 100644 --- a/sdks/python/apache_beam/options/value_provider_test.py +++ b/sdks/python/apache_beam/options/value_provider_test.py @@ -24,72 +24,77 @@ from apache_beam.options.value_provider import StaticValueProvider +# TODO(BEAM-1319): Require unique names only within a test. +# For now, _vp_arg will be the convention +# to name value-provider arguments in tests, as opposed to +# _non_vp_arg for non-value-provider arguments. +# The number will grow per file as tests are added. 
class ValueProviderTests(unittest.TestCase): def test_static_value_provider_keyword_argument(self): class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - '--vp_arg', + '--vpt_vp_arg1', help='This keyword argument is a value provider', default='some value') - options = UserDefinedOptions(['--vp_arg', 'abc']) - self.assertTrue(isinstance(options.vp_arg, StaticValueProvider)) - self.assertTrue(options.vp_arg.is_accessible()) - self.assertEqual(options.vp_arg.get(), 'abc') + options = UserDefinedOptions(['--vpt_vp_arg1', 'abc']) + self.assertTrue(isinstance(options.vpt_vp_arg1, StaticValueProvider)) + self.assertTrue(options.vpt_vp_arg1.is_accessible()) + self.assertEqual(options.vpt_vp_arg1.get(), 'abc') def test_runtime_value_provider_keyword_argument(self): class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - '--vp_arg', + '--vpt_vp_arg2', help='This keyword argument is a value provider') options = UserDefinedOptions() - self.assertTrue(isinstance(options.vp_arg, RuntimeValueProvider)) - self.assertFalse(options.vp_arg.is_accessible()) + self.assertTrue(isinstance(options.vpt_vp_arg2, RuntimeValueProvider)) + self.assertFalse(options.vpt_vp_arg2.is_accessible()) with self.assertRaises(RuntimeError): - options.vp_arg.get() + options.vpt_vp_arg2.get() def test_static_value_provider_positional_argument(self): class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - 'vp_pos_arg', + 'vpt_vp_arg3', help='This positional argument is a value provider', default='some value') options = UserDefinedOptions(['abc']) - self.assertTrue(isinstance(options.vp_pos_arg, StaticValueProvider)) - self.assertTrue(options.vp_pos_arg.is_accessible()) - self.assertEqual(options.vp_pos_arg.get(), 'abc') + self.assertTrue(isinstance(options.vpt_vp_arg3, StaticValueProvider)) + self.assertTrue(options.vpt_vp_arg3.is_accessible()) + self.assertEqual(options.vpt_vp_arg3.get(), 'abc') def test_runtime_value_provider_positional_argument(self): class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - 'vp_pos_arg', + 'vpt_vp_arg4', help='This positional argument is a value provider') options = UserDefinedOptions([]) - self.assertTrue(isinstance(options.vp_pos_arg, RuntimeValueProvider)) - self.assertFalse(options.vp_pos_arg.is_accessible()) + self.assertTrue(isinstance(options.vpt_vp_arg4, RuntimeValueProvider)) + self.assertFalse(options.vpt_vp_arg4.is_accessible()) with self.assertRaises(RuntimeError): - options.vp_pos_arg.get() + options.vpt_vp_arg4.get() def test_static_value_provider_type_cast(self): class UserDefinedOptions(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - '--vp_arg', + '--vpt_vp_arg5', type=int, help='This flag is a value provider') - options = UserDefinedOptions(['--vp_arg', '123']) - self.assertTrue(isinstance(options.vp_arg, StaticValueProvider)) - self.assertTrue(options.vp_arg.is_accessible()) - self.assertEqual(options.vp_arg.get(), 123) + options = UserDefinedOptions(['--vpt_vp_arg5', '123']) + self.assertTrue(isinstance(options.vpt_vp_arg5, StaticValueProvider)) + self.assertTrue(options.vpt_vp_arg5.is_accessible()) + self.assertEqual(options.vpt_vp_arg5.get(), 123) def test_set_runtime_option(self): # define ValueProvider ptions, with and 
without default values @@ -97,25 +102,25 @@ class UserDefinedOptions1(PipelineOptions): @classmethod def _add_argparse_args(cls, parser): parser.add_value_provider_argument( - '--vp_arg', + '--vpt_vp_arg6', help='This keyword argument is a value provider') # set at runtime parser.add_value_provider_argument( # not set, had default int - '-v', '--vp_arg2', # with short form + '-v', '--vpt_vp_arg7', # with short form default=123, type=int) parser.add_value_provider_argument( # not set, had default str - '--vp-arg3', # with dash in name + '--vpt_vp-arg8', # with dash in name default='123', type=str) parser.add_value_provider_argument( # not set and no default - '--vp_arg4', + '--vpt_vp_arg9', type=float) parser.add_value_provider_argument( # positional argument set - 'vp_pos_arg', # default & runtime ignored + 'vpt_vp_arg10', # default & runtime ignored help='This positional argument is a value provider', type=float, default=5.4) @@ -123,23 +128,23 @@ def _add_argparse_args(cls, parser): # provide values at graph-construction time # (options not provided here become of the type RuntimeValueProvider) options = UserDefinedOptions1(['1.2']) - self.assertFalse(options.vp_arg.is_accessible()) - self.assertFalse(options.vp_arg2.is_accessible()) - self.assertFalse(options.vp_arg3.is_accessible()) - self.assertFalse(options.vp_arg4.is_accessible()) - self.assertTrue(options.vp_pos_arg.is_accessible()) + self.assertFalse(options.vpt_vp_arg6.is_accessible()) + self.assertFalse(options.vpt_vp_arg7.is_accessible()) + self.assertFalse(options.vpt_vp_arg8.is_accessible()) + self.assertFalse(options.vpt_vp_arg9.is_accessible()) + self.assertTrue(options.vpt_vp_arg10.is_accessible()) # provide values at job-execution time # (options not provided here will use their default, if they have one) - RuntimeValueProvider.set_runtime_options({'vp_arg': 'abc', - 'vp_pos_arg':'3.2'}) - self.assertTrue(options.vp_arg.is_accessible()) - self.assertEqual(options.vp_arg.get(), 'abc') - self.assertTrue(options.vp_arg2.is_accessible()) - self.assertEqual(options.vp_arg2.get(), 123) - self.assertTrue(options.vp_arg3.is_accessible()) - self.assertEqual(options.vp_arg3.get(), '123') - self.assertTrue(options.vp_arg4.is_accessible()) - self.assertIsNone(options.vp_arg4.get()) - self.assertTrue(options.vp_pos_arg.is_accessible()) - self.assertEqual(options.vp_pos_arg.get(), 1.2) + RuntimeValueProvider.set_runtime_options({'vpt_vp_arg6': 'abc', + 'vpt_vp_arg10':'3.2'}) + self.assertTrue(options.vpt_vp_arg6.is_accessible()) + self.assertEqual(options.vpt_vp_arg6.get(), 'abc') + self.assertTrue(options.vpt_vp_arg7.is_accessible()) + self.assertEqual(options.vpt_vp_arg7.get(), 123) + self.assertTrue(options.vpt_vp_arg8.is_accessible()) + self.assertEqual(options.vpt_vp_arg8.get(), '123') + self.assertTrue(options.vpt_vp_arg9.is_accessible()) + self.assertIsNone(options.vpt_vp_arg9.get()) + self.assertTrue(options.vpt_vp_arg10.is_accessible()) + self.assertEqual(options.vpt_vp_arg10.get(), 1.2) From ee728f1b2f617dac8e5cd729cacf1a46911021e0 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Mon, 12 Jun 2017 23:11:22 -0700 Subject: [PATCH 033/200] Fix WindowValueCoder for large timestamps --- sdks/python/apache_beam/coders/coder_impl.py | 4 ++++ sdks/python/apache_beam/coders/coders_test_common.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 10298bfbf3b24..2670250c36b63 100644 --- 
a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -710,6 +710,10 @@ def decode_from_stream(self, in_stream, nested): timestamp = MAX_TIMESTAMP.micros else: timestamp *= 1000 + if timestamp > MAX_TIMESTAMP.micros: + timestamp = MAX_TIMESTAMP.micros + if timestamp < MIN_TIMESTAMP.micros: + timestamp = MIN_TIMESTAMP.micros windows = self._windows_coder.decode_from_stream(in_stream, True) # Read PaneInfo encoded byte. diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index c9b67b3462397..577c53aee8ddb 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -23,6 +23,8 @@ import dill +from apache_beam.transforms.window import GlobalWindow +from apache_beam.utils.timestamp import MIN_TIMESTAMP import observable from apache_beam.transforms import window from apache_beam.utils import timestamp @@ -287,6 +289,12 @@ def test_windowed_value_coder(self): # Test binary representation self.assertEqual('\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01', coder.encode(window.GlobalWindows.windowed_value(1))) + + # Test decoding large timestamp + self.assertEqual( + coder.decode('\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'), + windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(),))) + # Test unnested self.check_coder( coders.WindowedValueCoder(coders.VarIntCoder()), From 33662b92f1a9936f594223c9aee1a7233f59a569 Mon Sep 17 00:00:00 2001 From: "chamikara@google.com" Date: Thu, 8 Jun 2017 14:56:24 -0700 Subject: [PATCH 034/200] Adds ability to dynamically replace PTransforms during runtime. To this end, adds two interfaces, PTransformMatcher and PTransformOverride. Currently only supports replacements where input and output types are an exact match (we have to address complexities due to type hints before supporting replacements with different types). This will be used by SplittableDoFn where matching ParDo transforms will be dynamically replaced by SplittableParDo. 
--- sdks/python/apache_beam/pipeline.py | 201 ++++++++++++++++++ sdks/python/apache_beam/pipeline_test.py | 35 +++ .../runners/direct/direct_runner.py | 11 + 3 files changed, 247 insertions(+) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index cea7215b2b822..05715d7157042 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -45,6 +45,7 @@ from __future__ import absolute_import +import abc import collections import logging import os @@ -53,6 +54,7 @@ from apache_beam import pvalue from apache_beam.internal import pickler +from apache_beam.pvalue import PCollection from apache_beam.runners import create_runner from apache_beam.runners import PipelineRunner from apache_beam.transforms import ptransform @@ -157,6 +159,157 @@ def _root_transform(self): """Returns the root transform of the transform stack.""" return self.transforms_stack[0] + def _remove_labels_recursively(self, applied_transform): + for part in applied_transform.parts: + if part.full_label in self.applied_labels: + self.applied_labels.remove(part.full_label) + if part.parts: + for part2 in part.parts: + self._remove_labels_recursively(part2) + + def _replace(self, override): + + assert isinstance(override, PTransformOverride) + matcher = override.get_matcher() + + output_map = {} + output_replacements = {} + input_replacements = {} + + class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment + """"A visitor that replaces the matching PTransforms.""" + + def __init__(self, pipeline): + self.pipeline = pipeline + + def _replace_if_needed(self, transform_node): + if matcher(transform_node): + replacement_transform = override.get_replacement_transform( + transform_node.transform) + inputs = transform_node.inputs + # TODO: Support replacing PTransforms with multiple inputs. + if len(inputs) > 1: + raise NotImplementedError( + 'PTransform overriding is only supported for PTransforms that ' + 'have a single input. Tried to replace input of ' + 'AppliedPTransform %r that has %d inputs', + transform_node, len(inputs)) + transform_node.transform = replacement_transform + self.pipeline.transforms_stack.append(transform_node) + + # Keeping the same label for the replaced node but recursively + # removing labels of child transforms since they will be replaced + # during the expand below. + self.pipeline._remove_labels_recursively(transform_node) + + new_output = replacement_transform.expand(inputs[0]) + if new_output.producer is None: + # When current transform is a primitive, we set the producer here. + new_output.producer = transform_node + + # We only support replacing transforms with a single output with + # another transform that produces a single output. + # TODO: Support replacing PTransforms with multiple outputs. + if (len(transform_node.outputs) > 1 or + not isinstance(transform_node.outputs[None], PCollection) or + not isinstance(new_output, PCollection)): + raise NotImplementedError( + 'PTransform overriding is only supported for PTransforms that ' + 'have a single output. Tried to replace output of ' + 'AppliedPTransform %r with %r.' + , transform_node, new_output) + + # Recording updated outputs. This cannot be done in the same visitor + # since if we dynamically update output type here, we'll run into + # errors when visiting child nodes. 
+ output_map[transform_node.outputs[None]] = new_output + + self.pipeline.transforms_stack.pop() + + def enter_composite_transform(self, transform_node): + self._replace_if_needed(transform_node) + + def visit_transform(self, transform_node): + self._replace_if_needed(transform_node) + + self.visit(TransformUpdater(self)) + + # Adjusting inputs and outputs + class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignment + """"A visitor that records input and output values to be replaced. + + Input and output values that should be updated are recorded in maps + input_replacements and output_replacements respectively. + + We cannot update input and output values while visiting since that results + in validation errors. + """ + + def __init__(self, pipeline): + self.pipeline = pipeline + + def enter_composite_transform(self, transform_node): + self.visit_transform(transform_node) + + def visit_transform(self, transform_node): + if (None in transform_node.outputs and + transform_node.outputs[None] in output_map): + output_replacements[transform_node] = ( + output_map[transform_node.outputs[None]]) + + replace_input = False + for input in transform_node.inputs: + if input in output_map: + replace_input = True + break + + if replace_input: + new_input = [ + input if not input in output_map else output_map[input] + for input in transform_node.inputs] + input_replacements[transform_node] = new_input + + self.visit(InputOutputUpdater(self)) + + for transform in output_replacements: + transform.replace_output(output_replacements[transform]) + + for transform in input_replacements: + transform.inputs = input_replacements[transform] + + def _check_replacement(self, override): + matcher = override.get_matcher() + + class ReplacementValidator(PipelineVisitor): + def visit_transform(self, transform_node): + if matcher(transform_node): + raise RuntimeError('Transform node %r was not replaced as expected.', + transform_node) + + self.visit(ReplacementValidator()) + + def replace_all(self, replacements): + """ Dynamically replaces PTransforms in the currently populated hierarchy. + + Currently this only works for replacements where input and output types + are exactly the same. + TODO: Update this to also work for transform overrides where input and + output types are different. + + Args: + replacements a list of PTransformOverride objects. + """ + for override in replacements: + assert isinstance(override, PTransformOverride) + self._replace(override) + + # Checking if the PTransforms have been successfully replaced. This will + # result in a failure if a PTransform that was replaced in a given override + # gets re-added in a subsequent override. This is not allowed and ordering + # of PTransformOverride objects in 'replacements' is important. + for override in replacements: + self._check_replacement(override) + def run(self, test_runner_api=True): """Runs the pipeline. Returns whatever our runner returns after running.""" @@ -441,6 +594,20 @@ def real_producer(pv): for side_input in self.side_inputs: real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1 + def replace_output(self, output, tag=None): + """Replaces the output defined by the given tag with the given output. + + Args: + output: replacement output + tag: tag of the output to be replaced. 
+ """ + if isinstance(output, pvalue.DoOutputsTuple): + self.replace_output(output[output._main_tag]) + elif isinstance(output, pvalue.PValue): + self.outputs[tag] = output + else: + raise TypeError("Unexpected output type: %s" % output) + def add_output(self, output, tag=None): if isinstance(output, pvalue.DoOutputsTuple): self.add_output(output[output._main_tag]) @@ -564,3 +731,37 @@ def from_runner_api(proto, context): pc.tag = tag result.update_input_refcounts() return result + + +class PTransformOverride(object): + """For internal use only; no backwards-compatibility guarantees. + + Gives a matcher and replacements for matching PTransforms. + + TODO: Update this to support cases where input and/our output types are + different. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def get_matcher(self): + """Gives a matcher that will be used to to perform this override. + + Returns: + a callable that takes an AppliedPTransform as a parameter and returns a + boolean as a result. + """ + raise NotImplementedError + + @abc.abstractmethod + def get_replacement_transform(self, ptransform): + """Provides a runner specific override for a given PTransform. + + Args: + ptransform: PTransform to be replaced. + Returns: + A PTransform that will be the replacement for the PTransform given as an + argument. + """ + # Returns a PTransformReplacement + raise NotImplementedError diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index e0775d1036810..f9b894f72eb72 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -28,9 +28,11 @@ from apache_beam.io import Read from apache_beam.metrics import Metrics from apache_beam.pipeline import Pipeline +from apache_beam.pipeline import PTransformOverride from apache_beam.pipeline import PipelineOptions from apache_beam.pipeline import PipelineVisitor from apache_beam.pvalue import AsSingleton +from apache_beam.runners import DirectRunner from apache_beam.runners.dataflow.native_io.iobase import NativeSource from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that @@ -75,6 +77,18 @@ def reader(self): return FakeSource._Reader(self._vals) +class DoubleParDo(beam.PTransform): + def expand(self, input): + return input | 'Inner' >> beam.Map(lambda a: a * 2) + + +class TripleParDo(beam.PTransform): + def expand(self, input): + # Keeping labels the same intentionally to make sure that there is no label + # conflict due to replacement. + return input | 'Inner' >> beam.Map(lambda a: a * 3) + + class PipelineTest(unittest.TestCase): @staticmethod @@ -285,6 +299,27 @@ def raise_exception(exn): # p = Pipeline('EagerRunner') # self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x)) + def test_ptransform_overrides(self): + + def my_par_do_matcher(applied_ptransform): + return isinstance(applied_ptransform.transform, DoubleParDo) + + class MyParDoOverride(PTransformOverride): + + def get_matcher(self): + return my_par_do_matcher + + def get_replacement_transform(self, ptransform): + if isinstance(ptransform, DoubleParDo): + return TripleParDo() + raise ValueError('Unsupported type of transform: %r', ptransform) + + # Using following private variable for testing. 
+ DirectRunner._PTRANSFORM_OVERRIDES.append(MyParDoOverride()) + with Pipeline() as p: + pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo() + assert_that(pcoll, equal_to([3, 6, 9])) + class DoFnTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index ecf5114feefdb..323f44b4b4bd6 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -42,6 +42,14 @@ class DirectRunner(PipelineRunner): """Executes a single pipeline on the local machine.""" + # A list of PTransformOverride objects to be applied before running a pipeline + # using DirectRunner. + # Currently this only works for overrides where the input and output types do + # not change. + # For internal SDK use only. This should not be updated by Beam pipeline + # authors. + _PTRANSFORM_OVERRIDES = [] + def __init__(self): self._cache = None @@ -59,6 +67,9 @@ def apply_CombinePerKey(self, transform, pcoll): def run(self, pipeline): """Execute the entire pipeline and returns an DirectPipelineResult.""" + # Performing configured PTransform overrides. + pipeline.replace_all(DirectRunner._PTRANSFORM_OVERRIDES) + # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems # with resolving imports when they are at top. # pylint: disable=wrong-import-position From 3a48e47c73520580021c21a037ab412a761d0eeb Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Mon, 12 Jun 2017 11:51:39 -0700 Subject: [PATCH 035/200] Improves message when transitively serializing PipelineOptions --- .../src/main/resources/beam/findbugs-filter.xml | 9 +++++++++ .../beam/sdk/options/ProxyInvocationHandler.java | 15 ++++++++++++++- .../sdk/options/ProxyInvocationHandlerTest.java | 12 ++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml index 3430750d37a89..0c9080d408b70 100644 --- a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml +++ b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml @@ -412,4 +412,13 @@ + + + + + + diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java index eda21a8aadb3b..3842388e8c0b9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java @@ -45,6 +45,8 @@ import com.google.common.collect.MutableClassToInstanceMap; import java.beans.PropertyDescriptor; import java.io.IOException; +import java.io.NotSerializableException; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; @@ -87,7 +89,7 @@ * {@link PipelineOptions#as(Class)}. */ @ThreadSafe -class ProxyInvocationHandler implements InvocationHandler { +class ProxyInvocationHandler implements InvocationHandler, Serializable { /** * No two instances of this class are considered equivalent hence we generate a random hash code. 
*/ @@ -164,6 +166,17 @@ public Object invoke(Object proxy, Method method, Object[] args) { + Arrays.toString(args) + "]."); } + private void writeObject(java.io.ObjectOutputStream stream) + throws IOException { + throw new NotSerializableException( + "PipelineOptions objects are not serializable and should not be embedded into transforms " + + "(did you capture a PipelineOptions object in a field or in an anonymous class?). " + + "Instead, if you're using a DoFn, access PipelineOptions at runtime " + + "via ProcessContext/StartBundleContext/FinishBundleContext.getPipelineOptions(), " + + "or pre-extract necessary fields from PipelineOptions " + + "at pipeline construction time."); + } + /** * Track whether options values are explicitly set, or retrieved from defaults. */ diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java index 2c43f57a40fe9..d90cb4210139c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java @@ -44,6 +44,7 @@ import com.google.common.collect.Maps; import com.google.common.testing.EqualsTester; import java.io.IOException; +import java.io.NotSerializableException; import java.io.Serializable; import java.util.HashSet; import java.util.List; @@ -54,6 +55,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.hamcrest.Matchers; import org.joda.time.Instant; @@ -1019,4 +1021,14 @@ public void testDisplayDataExcludesJsonIgnoreOptions() { DisplayData data = DisplayData.from(options); assertThat(data, not(hasDisplayItem("value"))); } + + private static class CapturesOptions implements Serializable { + PipelineOptions options = PipelineOptionsFactory.create(); + } + + @Test + public void testOptionsAreNotSerializable() { + expectedException.expectCause(Matchers.instanceOf(NotSerializableException.class)); + SerializableUtils.clone(new CapturesOptions()); + } } From 9115af488ceb907de121313ffa096d58a0ccc1e1 Mon Sep 17 00:00:00 2001 From: Mairbek Khadikov Date: Wed, 7 Jun 2017 16:27:01 -0700 Subject: [PATCH 036/200] SpannerIO: Introduced a MutationGroup. Allows to group together mutation in a logical bundle that is submitted in the same transaction. 
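For example, a hypothetical pipeline snippet using the new API (the table and column
names, the literal values, and the project/instance/database ids are illustrative
placeholders; only MutationGroup.create() and SpannerIO.write().grouped() come from
this patch):

    import com.google.cloud.spanner.Mutation;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.spanner.MutationGroup;
    import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.values.PCollection;

    public class GroupedSpannerWrite {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        // A parent row and an interleaved child row that must be applied atomically.
        Mutation album =
            Mutation.newInsertOrUpdateBuilder("Albums")
                .set("AlbumId").to(1L)
                .build();
        Mutation song =
            Mutation.newInsertOrUpdateBuilder("Songs")
                .set("AlbumId").to(1L)
                .set("SongId").to(1L)
                .build();

        PCollection<MutationGroup> groups =
            p.apply(Create.of(MutationGroup.create(album, song)));

        // grouped() adapts the sink to PCollection<MutationGroup>; mutations in one
        // group are always submitted in the same writeAtLeastOnce() call.
        groups.apply(
            SpannerIO.write()
                .withProjectId("my-project")
                .withInstanceId("my-instance")
                .withDatabaseId("my-database")
                .grouped());

        p.run();
      }
    }

Mutations belonging to one group are never split across batches, so a parent row and its
interleaved children are committed together.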
--- .../sdk/io/gcp/spanner/MutationGroup.java | 67 ++++++++++++++++ .../io/gcp/spanner/MutationSizeEstimator.java | 9 +++ .../beam/sdk/io/gcp/spanner/SpannerIO.java | 53 +++++++++++-- .../spanner/MutationSizeEstimatorTest.java | 12 +++ .../sdk/io/gcp/spanner/SpannerIOTest.java | 76 ++++++++++++++++--- 5 files changed, 197 insertions(+), 20 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java new file mode 100644 index 0000000000000..5b08da2f25361 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationGroup.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.Mutation; +import com.google.common.collect.ImmutableList; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * A bundle of mutations that must be submitted atomically. + * + *
One of the mutations is chosen to be "primary", and can be used to determine partitions. + */ +public final class MutationGroup implements Serializable, Iterable { + private final ImmutableList mutations; + + /** + * Creates a new group. + * + * @param primary a primary mutation. + * @param other other mutations, usually interleaved in parent. + * @return new mutation group. + */ + public static MutationGroup create(Mutation primary, Mutation... other) { + return create(primary, Arrays.asList(other)); + } + + public static MutationGroup create(Mutation primary, Iterable other) { + return new MutationGroup(ImmutableList.builder().add(primary).addAll(other).build()); + } + + @Override + public Iterator iterator() { + return mutations.iterator(); + } + + private MutationGroup(ImmutableList mutations) { + this.mutations = mutations; + } + + public Mutation primary() { + return mutations.get(0); + } + + public List attached() { + return mutations.subList(1, mutations.size()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java index 61652e736e908..241881693f8da 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java @@ -44,6 +44,15 @@ static long sizeOf(Mutation m) { return result; } + /** Estimates a size of the mutation group in bytes. */ + public static long sizeOf(MutationGroup group) { + long result = 0; + for (Mutation m : group) { + result += sizeOf(m); + } + return result; + } + private static long estimatePrimitiveValue(Value v) { switch (v.getType().getCode()) { case BOOL: diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 5058d13f77bb1..af5253ba1f3b9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -29,10 +29,12 @@ import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; import java.io.IOException; import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; + import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; @@ -88,6 +90,11 @@ *
  • If the pipeline was unexpectedly stopped, mutations that were already applied will not get * rolled back. * + * + *
    Use {@link MutationGroup} to ensure that a small set mutations is bundled together. It is + * guaranteed that mutations in a group are submitted in the same transaction. Build + * {@link SpannerIO.Write} transform, and call {@link Write#grouped()} method. It will return a + * transformation that can be applied to a PCollection of MutationGroup. */ @Experimental(Experimental.Kind.SOURCE_SINK) public class SpannerIO { @@ -187,6 +194,13 @@ public Write withDatabaseId(String databaseId) { return toBuilder().setDatabaseId(databaseId).build(); } + /** + * Same transform but can be applied to {@link PCollection} of {@link MutationGroup}. + */ + public WriteGrouped grouped() { + return new WriteGrouped(this); + } + @VisibleForTesting Write withServiceFactory(ServiceFactory serviceFactory) { return toBuilder().setServiceFactory(serviceFactory).build(); @@ -204,7 +218,9 @@ public void validate(PipelineOptions options) { @Override public PDone expand(PCollection input) { - input.apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteFn(this))); + input + .apply("To mutation group", ParDo.of(new ToMutationGroupFn())) + .apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteGroupFn(this))); return PDone.in(input.getPipeline()); } @@ -227,15 +243,37 @@ public void populateDisplayData(DisplayData.Builder builder) { } } + /** Same as {@link Write} but supports grouped mutations. */ + public static class WriteGrouped extends PTransform, PDone> { + private final Write spec; + + public WriteGrouped(Write spec) { + this.spec = spec; + } + + @Override public PDone expand(PCollection input) { + input.apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteGroupFn(spec))); + return PDone.in(input.getPipeline()); + } + } + + private static class ToMutationGroupFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + Mutation value = c.element(); + c.output(MutationGroup.create(value)); + } + } + /** Batches together and writes mutations to Google Cloud Spanner. */ @VisibleForTesting - static class SpannerWriteFn extends DoFn { - private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteFn.class); + static class SpannerWriteGroupFn extends DoFn { + private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteGroupFn.class); private final Write spec; private transient Spanner spanner; private transient DatabaseClient dbClient; // Current batch of mutations to be written. - private List mutations; + private List mutations; private long batchSizeBytes = 0; private static final int MAX_RETRIES = 5; @@ -244,8 +282,7 @@ static class SpannerWriteFn extends DoFn { .withMaxRetries(MAX_RETRIES) .withInitialBackoff(Duration.standardSeconds(5)); - @VisibleForTesting - SpannerWriteFn(Write spec) { + @VisibleForTesting SpannerWriteGroupFn(Write spec) { this.spec = spec; } @@ -261,7 +298,7 @@ public void setup() throws Exception { @ProcessElement public void processElement(ProcessContext c) throws Exception { - Mutation m = c.element(); + MutationGroup m = c.element(); mutations.add(m); batchSizeBytes += MutationSizeEstimator.sizeOf(m); if (batchSizeBytes >= spec.getBatchSizeBytes()) { @@ -319,7 +356,7 @@ private void flushBatch() throws AbortedException, IOException, InterruptedExcep while (true) { // Batch upsert rows. try { - dbClient.writeAtLeastOnce(mutations); + dbClient.writeAtLeastOnce(Iterables.concat(mutations)); // Break if the commit threw no exception. 
break; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java index 03eb28ed943dd..013b83d458663 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java @@ -135,4 +135,16 @@ public void dates() throws Exception { assertThat(MutationSizeEstimator.sizeOf(timestampArray), is(24L)); assertThat(MutationSizeEstimator.sizeOf(dateArray), is(48L)); } + + @Test + public void group() throws Exception { + Mutation int64 = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation float64 = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2.9).build(); + Mutation bool = Mutation.newInsertOrUpdateBuilder("test").set("one").to(false).build(); + + MutationGroup group = MutationGroup.create(int64, float64, bool); + + assertThat(MutationSizeEstimator.sizeOf(group), is(17L)); + } + } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java index 5bdfea5522b24..4a759fb119173 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java @@ -114,9 +114,31 @@ public void singleMutationPipeline() throws Exception { } @Test - public void batching() throws Exception { + @Category(NeedsRunner.class) + public void singleMutationGroupPipeline() throws Exception { Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build(); + PCollection mutations = pipeline + .apply(Create.of(g(one, two, three))); + mutations.apply( + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withServiceFactory(serviceFactory) + .grouped()); + pipeline.run(); + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(3))); + } + + @Test + public void batching() throws Exception { + MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build()); + MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build()); SpannerIO.Write write = SpannerIO.write() .withProjectId("test-project") @@ -124,8 +146,8 @@ public void batching() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(1000000000) .withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); - DoFnTester fnTester = DoFnTester.of(writerFn); + SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two)); verify(serviceFactory.mockSpanner()) @@ -136,9 +158,9 @@ public void batching() throws Exception 
{ @Test public void batchingGroups() throws Exception { - Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); - Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); - Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build(); + MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build()); + MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build()); + MutationGroup three = g(Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build()); // Have a room to accumulate one more item. long batchSize = MutationSizeEstimator.sizeOf(one) + 1; @@ -150,8 +172,8 @@ public void batchingGroups() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(batchSize) .withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); - DoFnTester fnTester = DoFnTester.of(writerFn); + SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two, three)); verify(serviceFactory.mockSpanner()) @@ -164,8 +186,8 @@ public void batchingGroups() throws Exception { @Test public void noBatching() throws Exception { - Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); - Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + MutationGroup one = g(Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build()); + MutationGroup two = g(Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build()); SpannerIO.Write write = SpannerIO.write() .withProjectId("test-project") @@ -173,8 +195,8 @@ public void noBatching() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(0) // turn off batching. 
.withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); - DoFnTester fnTester = DoFnTester.of(writerFn); + SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two)); verify(serviceFactory.mockSpanner()) @@ -183,6 +205,32 @@ public void noBatching() throws Exception { .writeAtLeastOnce(argThat(new IterableOfSize(1))); } + @Test + public void groups() throws Exception { + Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build(); + + // Smallest batch size + long batchSize = 1; + + SpannerIO.Write write = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withBatchSizeBytes(batchSize) + .withServiceFactory(serviceFactory); + SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); + fnTester.processBundle(Arrays.asList(g(one, two, three))); + + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(3))); + } + private static class FakeServiceFactory implements ServiceFactory, Serializable { // Marked as static so they could be returned by serviceFactory, which is serializable. @@ -241,4 +289,8 @@ public boolean matches(Object argument) { return argument instanceof Iterable && Iterables.size((Iterable) argument) == size; } } + + private static MutationGroup g(Mutation m, Mutation... 
other) { + return MutationGroup.create(m, other); + } } From 7e4e51fda633d208c4cc5e88182e5db16156f2cb Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Tue, 13 Jun 2017 15:03:15 -0700 Subject: [PATCH 037/200] Move Runner API protos to portability/api --- .gitignore | 2 +- sdks/python/apache_beam/coders/coders.py | 2 +- sdks/python/apache_beam/pipeline.py | 4 ++-- .../portability/{runners => }/api/__init__.py | 0 .../portability/runners/__init__.py | 18 ------------------ sdks/python/apache_beam/pvalue.py | 2 +- .../runners/dataflow/dataflow_runner.py | 4 ++-- .../apache_beam/runners/pipeline_context.py | 2 +- .../runners/portability/fn_api_runner.py | 2 +- .../apache_beam/runners/worker/data_plane.py | 2 +- .../runners/worker/data_plane_test.py | 2 +- .../apache_beam/runners/worker/log_handler.py | 2 +- .../runners/worker/log_handler_test.py | 2 +- .../apache_beam/runners/worker/sdk_worker.py | 2 +- .../runners/worker/sdk_worker_main.py | 2 +- .../runners/worker/sdk_worker_test.py | 2 +- sdks/python/apache_beam/transforms/core.py | 2 +- .../apache_beam/transforms/ptransform.py | 2 +- sdks/python/apache_beam/transforms/trigger.py | 2 +- sdks/python/apache_beam/transforms/window.py | 4 ++-- sdks/python/apache_beam/utils/urns.py | 2 +- sdks/python/gen_protos.py | 2 +- sdks/python/run_pylint.sh | 2 +- 23 files changed, 24 insertions(+), 42 deletions(-) rename sdks/python/apache_beam/portability/{runners => }/api/__init__.py (100%) delete mode 100644 sdks/python/apache_beam/portability/runners/__init__.py diff --git a/.gitignore b/.gitignore index 631d7f32cb965..36c5cc8774ea4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ sdks/python/**/*.egg sdks/python/LICENSE sdks/python/NOTICE sdks/python/README.md -sdks/python/apache_beam/portability/runners/api/*pb2*.* +sdks/python/apache_beam/portability/api/*pb2*.* # Ignore IntelliJ files. .idea/ diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 1be1f3c7a4775..c56ef52301b7d 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -25,7 +25,7 @@ import google.protobuf from apache_beam.coders import coder_impl -from apache_beam.portability.runners.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils import urns from apache_beam.utils import proto_utils diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 05715d7157042..ab77956a0c1a2 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -492,7 +492,7 @@ def visit_value(self, value, _): def to_runner_api(self): """For internal use only; no backwards-compatibility guarantees.""" from apache_beam.runners import pipeline_context - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 context = pipeline_context.PipelineContext() # Mutates context; placing inline would force dependence on # argument evaluation order. 
@@ -692,7 +692,7 @@ def named_outputs(self): if isinstance(output, pvalue.PCollection)} def to_runner_api(self, context): - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 def transform_to_runner_api(transform, context): if transform is None: diff --git a/sdks/python/apache_beam/portability/runners/api/__init__.py b/sdks/python/apache_beam/portability/api/__init__.py similarity index 100% rename from sdks/python/apache_beam/portability/runners/api/__init__.py rename to sdks/python/apache_beam/portability/api/__init__.py diff --git a/sdks/python/apache_beam/portability/runners/__init__.py b/sdks/python/apache_beam/portability/runners/__init__.py deleted file mode 100644 index 0bce5d68f7243..0000000000000 --- a/sdks/python/apache_beam/portability/runners/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""For internal use only; no backwards-compatibility guarantees.""" diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 8a774c4c5bf45..34a483e7bb9c5 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -128,7 +128,7 @@ def __reduce_ex__(self, unused_version): return _InvalidUnpickledPCollection, () def to_runner_api(self, context): - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.internal import pickler return beam_runner_api_pb2.PCollection( unique_name='%d%s.%s' % ( diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index a6cc25d715127..d6944b2802647 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -732,7 +732,7 @@ def run__NativeWrite(self, transform_node): @classmethod def serialize_windowing_strategy(cls, windowing): from apache_beam.runners import pipeline_context - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 context = pipeline_context.PipelineContext() windowing_proto = windowing.to_runner_api(context) return cls.byte_array_to_json_string( @@ -745,7 +745,7 @@ def deserialize_windowing_strategy(cls, serialized_data): # Imported here to avoid circular dependencies. 
# pylint: disable=wrong-import-order, wrong-import-position from apache_beam.runners import pipeline_context - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms.core import Windowing proto = beam_runner_api_pb2.MessageWithComponents() proto.ParseFromString(cls.json_string_to_byte_array(serialized_data)) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 1330c3904edfe..e212abf8d9fc2 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -24,7 +24,7 @@ from apache_beam import pipeline from apache_beam import pvalue from apache_beam import coders -from apache_beam.portability.runners.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms import core diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index 8c213ad08f381..90764f4dfac01 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -33,7 +33,7 @@ from apache_beam.internal import pickler from apache_beam.io import iobase from apache_beam.transforms.window import GlobalWindows -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.portability import maptask_executor_runner from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import operation_specs diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 734ee9cda36a3..bc981a8d30edd 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -28,7 +28,7 @@ import threading from apache_beam.coders import coder_impl -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py index a2b31e8eb72ad..360468a868747 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane_test.py +++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py @@ -29,7 +29,7 @@ from concurrent import futures import grpc -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker import data_plane diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index dca0e4bd11b4b..b8f635210d2ee 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -21,7 +21,7 @@ import Queue as queue import threading -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. 
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py index 6dd018f6ad457..2256bb5556f0b 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler_test.py +++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py @@ -22,7 +22,7 @@ from concurrent import futures import grpc -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker import log_handler diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index dc4f5c2ffab0b..f662538e981dc 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -38,7 +38,7 @@ from apache_beam.io import iobase from apache_beam.runners.dataflow.native_io import iobase as native_iobase from apache_beam.utils import counters -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import operations diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index 9c11068a972e2..f3f1e023e0fc1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -24,7 +24,7 @@ import grpc from google.protobuf import text_format -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler from apache_beam.runners.worker.sdk_worker import SdkHarness diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index 93f60d3e7424a..c431bcdf24576 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -29,7 +29,7 @@ from apache_beam.io.concat_source_test import RangeSource from apache_beam.io.iobase import SourceBundle -from apache_beam.portability.runners.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index d7fa770af3be7..a137a1357ad38 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -27,7 +27,7 @@ from apache_beam import typehints from apache_beam.coders import typecoders from apache_beam.internal import util -from apache_beam.portability.runners.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms import ptransform from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 79fe3add26ddf..60413535f65a9 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -430,7 +430,7 @@ def register_urn(cls, urn, parameter_type, constructor): cls._known_urns[urn] = parameter_type, constructor def 
to_runner_api(self, context): - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.FunctionSpec( urn=urn, diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 41516070e8e3f..89c6ec535db9d 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -33,7 +33,7 @@ from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import WindowedValue from apache_beam.transforms.window import WindowFn -from apache_beam.portability.runners.api import beam_runner_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index 08c7a2d132f7b..458fb747eb226 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -55,8 +55,8 @@ from google.protobuf import timestamp_pb2 from apache_beam.coders import coders -from apache_beam.portability.runners.api import beam_runner_api_pb2 -from apache_beam.portability.runners.api import standard_window_fns_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import standard_window_fns_pb2 from apache_beam.transforms import timeutil from apache_beam.utils import proto_utils from apache_beam.utils import urns diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index b925bcc9fbcdd..e553eea95f6a0 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -102,7 +102,7 @@ def to_runner_api(self, context): Prefer overriding self.to_runner_api_parameter. """ - from apache_beam.portability.runners.api import beam_runner_api_pb2 + from apache_beam.portability.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.SdkFunctionSpec( spec=beam_runner_api_pb2.FunctionSpec( diff --git a/sdks/python/gen_protos.py b/sdks/python/gen_protos.py index a33c74b9cd92d..a3d963d18d781 100644 --- a/sdks/python/gen_protos.py +++ b/sdks/python/gen_protos.py @@ -35,7 +35,7 @@ os.path.join('..', 'common', 'fn-api', 'src', 'main', 'proto') ] -PYTHON_OUTPUT_PATH = os.path.join('apache_beam', 'portability', 'runners', 'api') +PYTHON_OUTPUT_PATH = os.path.join('apache_beam', 'portability', 'api') def generate_proto_files(): diff --git a/sdks/python/run_pylint.sh b/sdks/python/run_pylint.sh index 7434516bfdeb4..2691be4ea42c0 100755 --- a/sdks/python/run_pylint.sh +++ b/sdks/python/run_pylint.sh @@ -46,7 +46,7 @@ EXCLUDED_GENERATED_FILES=( "apache_beam/io/gcp/internal/clients/storage/storage_v1_client.py" "apache_beam/io/gcp/internal/clients/storage/storage_v1_messages.py" "apache_beam/coders/proto2_coder_test_messages_pb2.py" -apache_beam/portability/runners/api/*pb2*.py +apache_beam/portability/api/*pb2*.py ) FILES_TO_IGNORE="" From 8a850af3304a48618739dc23e286800dc0c4641a Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 13 Jun 2017 12:50:58 -0700 Subject: [PATCH 038/200] Do not produce Unprocessed Inputs if all inputs were Processed This stops the WatermarkManager "Pending Bundles" from growing without bound. 
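The core of the change: getUnprocessedInputs() now returns a Guava Optional that is absent
whenever every element of the input bundle was processed (or there was no input bundle at
all), instead of a nullable, possibly empty bundle. A toy sketch of the same pattern, with
illustrative class and variable names and List<String> standing in for the runner's
CommittedBundle:

    import com.google.common.base.Optional;
    import com.google.common.collect.Iterables;
    import java.util.Collections;
    import java.util.List;

    public class ResidualSketch {

      // Absent when nothing was left unprocessed, present otherwise; this mirrors the
      // shape of the helper this patch adds to EvaluationContext.
      static Optional<List<String>> unprocessedInput(List<String> unprocessed) {
        if (unprocessed == null || Iterables.isEmpty(unprocessed)) {
          return Optional.<List<String>>absent();
        }
        return Optional.of(unprocessed);
      }

      public static void main(String[] args) {
        // A fully processed input: no residual bundle, nothing is queued as pending.
        System.out.println(
            unprocessedInput(Collections.<String>emptyList()).isPresent());        // false
        // A partially processed input: the leftovers are carried forward explicitly.
        System.out.println(
            unprocessedInput(Collections.singletonList("leftover")).isPresent());  // true
      }
    }

Callers then branch on isPresent() instead of combining a null check with an emptiness
check, and an absent result adds nothing to the pending-bundle bookkeeping.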
--- .../beam/runners/direct/CommittedResult.java | 12 ++++----- .../runners/direct/EvaluationContext.java | 26 ++++++++++++++----- .../ExecutorServiceParallelExecutor.java | 9 ++++--- .../beam/runners/direct/WatermarkManager.java | 4 +-- .../runners/direct/CommittedResultTest.java | 17 +++++++----- .../runners/direct/TransformExecutorTest.java | 11 ++++++-- .../runners/direct/WatermarkManagerTest.java | 15 ++++++++--- 7 files changed, 62 insertions(+), 32 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java index 8c45449491b52..70e3ac3c988a3 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java @@ -19,8 +19,8 @@ package org.apache.beam.runners.direct; import com.google.auto.value.AutoValue; +import com.google.common.base.Optional; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; @@ -36,12 +36,10 @@ abstract class CommittedResult { /** * Returns the {@link CommittedBundle} that contains the input elements that could not be - * processed by the evaluation. - * - *
    {@code null} if the input bundle was null. + * processed by the evaluation. The returned optional is present if there were any unprocessed + * input elements, and absent otherwise. */ - @Nullable - public abstract CommittedBundle getUnprocessedInputs(); + public abstract Optional> getUnprocessedInputs(); /** * Returns the outputs produced by the transform. @@ -59,7 +57,7 @@ abstract class CommittedResult { public static CommittedResult create( TransformResult original, - CommittedBundle unprocessedElements, + Optional> unprocessedElements, Iterable> outputs, Set producedOutputs) { return new AutoValue_CommittedResult(original.getTransform(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java index e215070dac073..d192785681ec1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; @@ -158,12 +159,9 @@ public CommittedResult handleResult( } else { outputTypes.add(OutputType.BUNDLE); } - CommittedResult committedResult = CommittedResult.create(result, - completedBundle == null - ? null - : completedBundle.withElements((Iterable) result.getUnprocessedElements()), - committedBundles, - outputTypes); + CommittedResult committedResult = + CommittedResult.create( + result, getUnprocessedInput(completedBundle, result), committedBundles, outputTypes); // Update state internals CopyOnAccessInMemoryStateInternals theirState = result.getState(); if (theirState != null) { @@ -187,6 +185,22 @@ public CommittedResult handleResult( return committedResult; } + /** + * Returns an {@link Optional} containing a bundle which contains all of the unprocessed elements + * that were not processed from the {@code completedBundle}. If all of the elements of the {@code + * completedBundle} were processed, or if {@code completedBundle} is null, returns an absent + * {@link Optional}. 
+ */ + private Optional> getUnprocessedInput( + @Nullable CommittedBundle completedBundle, TransformResult result) { + if (completedBundle == null || Iterables.isEmpty(result.getUnprocessedElements())) { + return Optional.absent(); + } + CommittedBundle residual = + completedBundle.withElements((Iterable) result.getUnprocessedElements()); + return Optional.of(residual); + } + private Iterable> commitBundles( Iterable> bundles) { ImmutableList.Builder> completed = ImmutableList.builder(); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 6fe8ebd2609c7..2f4d1f64ec43a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -357,15 +357,16 @@ public final CommittedResult handleResult( ExecutorUpdate.fromBundle( outputBundle, graph.getPerElementConsumers(outputBundle.getPCollection()))); } - CommittedBundle unprocessedInputs = committedResult.getUnprocessedInputs(); - if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) { + Optional> unprocessedInputs = + committedResult.getUnprocessedInputs(); + if (unprocessedInputs.isPresent()) { if (inputBundle.getPCollection() == null) { // TODO: Split this logic out of an if statement - pendingRootBundles.get(result.getTransform()).offer(unprocessedInputs); + pendingRootBundles.get(result.getTransform()).offer(unprocessedInputs.get()); } else { allUpdates.offer( ExecutorUpdate.fromBundle( - unprocessedInputs, + unprocessedInputs.get(), Collections.>singleton( committedResult.getTransform()))); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 80a3504599d8c..599b74fb5fba1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -994,9 +994,9 @@ private void updatePending( } TransformWatermarks completedTransform = transformToWatermarks.get(result.getTransform()); - if (input != null) { + if (result.getUnprocessedInputs().isPresent()) { // Add the unprocessed inputs - completedTransform.addPending(result.getUnprocessedInputs()); + completedTransform.addPending(result.getUnprocessedInputs().get()); } completedTransform.updateTimers(timerUpdate); if (input != null) { diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java index cf19dc20e7e65..8b95b345d5548 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java @@ -18,9 +18,9 @@ package org.apache.beam.runners.direct; -import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import java.io.Serializable; import java.util.Collections; @@ -72,7 +72,7 @@ public void getTransformExtractsFromResult() { CommittedResult result = CommittedResult.create( 
StepTransformResult.withoutHold(transform).build(), - bundleFactory.createBundle(created).commit(Instant.now()), + Optional.>absent(), Collections.>emptyList(), EnumSet.noneOf(OutputType.class)); @@ -88,11 +88,11 @@ public void getUncommittedElementsEqualInput() { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - bundle, + Optional.of(bundle), Collections.>emptyList(), EnumSet.noneOf(OutputType.class)); - assertThat(result.getUnprocessedInputs(), + assertThat(result.getUnprocessedInputs().get(), Matchers.>equalTo(bundle)); } @@ -101,11 +101,14 @@ public void getUncommittedElementsNull() { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - null, + Optional.>absent(), Collections.>emptyList(), EnumSet.noneOf(OutputType.class)); - assertThat(result.getUnprocessedInputs(), nullValue()); + assertThat( + result.getUnprocessedInputs(), + Matchers.>>equalTo( + Optional.>absent())); } @Test @@ -120,7 +123,7 @@ public void getOutputsEqualInput() { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - bundleFactory.createBundle(created).commit(Instant.now()), + Optional.>absent(), outputs, EnumSet.of(OutputType.BUNDLE, OutputType.PCOLLECTION_VIEW)); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java index 3dd4028af6054..b7f5a7c14fec0 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TransformExecutorTest.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertThat; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; import java.util.ArrayList; import java.util.Collection; @@ -415,8 +417,13 @@ public CommittedResult handleResult(CommittedBundle inputBundle, TransformRes ? Collections.emptyList() : result.getUnprocessedElements(); - CommittedBundle unprocessedBundle = - inputBundle == null ? 
null : inputBundle.withElements(unprocessedElements); + Optional> unprocessedBundle; + if (inputBundle == null || Iterables.isEmpty(unprocessedElements)) { + unprocessedBundle = Optional.absent(); + } else { + unprocessedBundle = + Optional.>of(inputBundle.withElements(unprocessedElements)); + } return CommittedResult.create( result, unprocessedBundle, diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java index e0b52515d8c63..e3f62155734ea 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WatermarkManagerTest.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; +import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -318,7 +319,7 @@ public void getWatermarkMultiIdenticalInput() { TimerUpdate.empty(), CommittedResult.create( StepTransformResult.withoutHold(graph.getProducer(created)).build(), - root.withElements(Collections.>emptyList()), + Optional.>absent(), Collections.singleton(createBundle), EnumSet.allOf(OutputType.class)), BoundedWindow.TIMESTAMP_MAX_VALUE); @@ -332,7 +333,7 @@ public void getWatermarkMultiIdenticalInput() { TimerUpdate.empty(), CommittedResult.create( StepTransformResult.withoutHold(theFlatten).build(), - createBundle.withElements(Collections.>emptyList()), + Optional.>absent(), Collections.>emptyList(), EnumSet.allOf(OutputType.class)), BoundedWindow.TIMESTAMP_MAX_VALUE); @@ -345,7 +346,7 @@ public void getWatermarkMultiIdenticalInput() { TimerUpdate.empty(), CommittedResult.create( StepTransformResult.withoutHold(theFlatten).build(), - createBundle.withElements(Collections.>emptyList()), + Optional.>absent(), Collections.>emptyList(), EnumSet.allOf(OutputType.class)), BoundedWindow.TIMESTAMP_MAX_VALUE); @@ -1501,9 +1502,15 @@ private CommittedResult result( AppliedPTransform transform, @Nullable CommittedBundle unprocessedBundle, Iterable> bundles) { + Optional> unprocessedElements; + if (unprocessedBundle == null || Iterables.isEmpty(unprocessedBundle.getElements())) { + unprocessedElements = Optional.absent(); + } else { + unprocessedElements = Optional.of(unprocessedBundle); + } return CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - unprocessedBundle, + unprocessedElements, bundles, Iterables.isEmpty(bundles) ? 
EnumSet.noneOf(OutputType.class) From 0b600d20de2cf2e6071d1d288d4b6a4795df710a Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 7 Jun 2017 16:09:10 -0700 Subject: [PATCH 039/200] Choose GroupAlsoByWindows implementation based on streaming flag --- .../apache_beam/options/pipeline_options.py | 9 -- .../runners/direct/direct_runner.py | 28 ++++++ sdks/python/apache_beam/transforms/core.py | 89 +++++++++++-------- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index daef3a71bb281..8598e057e83c1 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -18,7 +18,6 @@ """Pipeline options obtained from command line parsing.""" import argparse -import warnings from apache_beam.transforms.display import HasDisplayData from apache_beam.options.value_provider import StaticValueProvider @@ -279,14 +278,6 @@ def _add_argparse_args(cls, parser): action='store_true', help='Whether to enable streaming mode.') - # TODO(BEAM-1265): Remove this warning, once at least one runner supports - # streaming pipelines. - def validate(self, validator): - errors = [] - if self.view_as(StandardOptions).streaming: - warnings.warn('Streaming pipelines are not supported.') - return errors - class TypeOptions(PipelineOptions): diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 323f44b4b4bd6..d80ef102e0363 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -26,19 +26,34 @@ import collections import logging +from apache_beam import typehints from apache_beam.metrics.execution import MetricsEnvironment from apache_beam.runners.direct.bundle_factory import BundleFactory from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache +from apache_beam.transforms.core import _GroupAlsoByWindow from apache_beam.options.pipeline_options import DirectOptions +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import RuntimeValueProvider __all__ = ['DirectRunner'] +# Type variables. +K = typehints.TypeVariable('K') +V = typehints.TypeVariable('V') + + +@typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]]) +@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) +class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow): + """Streaming GroupAlsoByWindow placeholder for overriding in DirectRunner.""" + pass + + class DirectRunner(PipelineRunner): """Executes a single pipeline on the local machine.""" @@ -64,6 +79,19 @@ def apply_CombinePerKey(self, transform, pcoll): except NotImplementedError: return transform.expand(pcoll) + def apply__GroupAlsoByWindow(self, transform, pcoll): + if (transform.__class__ == _GroupAlsoByWindow and + pcoll.pipeline._options.view_as(StandardOptions).streaming): + # Use specialized streaming implementation, if requested. + raise NotImplementedError( + 'Streaming support is not yet available on the DirectRunner.') + # TODO(ccy): enable when streaming implementation is plumbed through. 
+ # type_hints = transform.get_type_hints() + # return pcoll | (_StreamingGroupAlsoByWindow(transform.windowing) + # .with_input_types(*type_hints.input_types[0]) + # .with_output_types(*type_hints.output_types[0])) + return transform.expand(pcoll) + def run(self, pipeline): """Execute the entire pipeline and returns an DirectPipelineResult.""" diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index a137a1357ad38..c30136de2a439 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1078,40 +1078,6 @@ def infer_output_type(self, input_type): key_type, value_type = trivial_inference.key_value_types(input_type) return Iterable[KV[key_type, typehints.WindowedValue[value_type]]] - class GroupAlsoByWindow(DoFn): - # TODO(robertwb): Support combiner lifting. - - def __init__(self, windowing): - super(GroupByKey.GroupAlsoByWindow, self).__init__() - self.windowing = windowing - - def infer_output_type(self, input_type): - key_type, windowed_value_iter_type = trivial_inference.key_value_types( - input_type) - value_type = windowed_value_iter_type.inner_type.inner_type - return Iterable[KV[key_type, Iterable[value_type]]] - - def start_bundle(self): - # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam.transforms.trigger import InMemoryUnmergedState - from apache_beam.transforms.trigger import create_trigger_driver - # pylint: enable=wrong-import-order, wrong-import-position - self.driver = create_trigger_driver(self.windowing, True) - self.state_type = InMemoryUnmergedState - - def process(self, element): - k, vs = element - state = self.state_type() - # TODO(robertwb): Conditionally process in smaller chunks. - for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): - yield wvalue.with_value((k, wvalue.value)) - while state.timers: - fired = state.get_and_clear_timers() - for timer_window, (name, time_domain, fire_time) in fired: - for wvalue in self.driver.process_timer( - timer_window, name, time_domain, fire_time, state): - yield wvalue.with_value((k, wvalue.value)) - def expand(self, pcoll): # This code path is only used in the local direct runner. For Dataflow # runner execution, the GroupByKey transform is expanded on the service. 
@@ -1136,8 +1102,7 @@ def expand(self, pcoll): | 'GroupByKey' >> (_GroupByKeyOnly() .with_input_types(reify_output_type) .with_output_types(gbk_input_type)) - | ('GroupByWindow' >> ParDo( - self.GroupAlsoByWindow(pcoll.windowing)) + | ('GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing) .with_input_types(gbk_input_type) .with_output_types(gbk_output_type))) else: @@ -1145,8 +1110,7 @@ def expand(self, pcoll): return (pcoll | 'ReifyWindows' >> ParDo(self.ReifyWindows()) | 'GroupByKey' >> _GroupByKeyOnly() - | 'GroupByWindow' >> ParDo( - self.GroupAlsoByWindow(pcoll.windowing))) + | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing)) @typehints.with_input_types(typehints.KV[K, V]) @@ -1162,6 +1126,55 @@ def expand(self, pcoll): return pvalue.PCollection(pcoll.pipeline) +@typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]]) +@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) +class _GroupAlsoByWindow(ParDo): + """The GroupAlsoByWindow transform.""" + def __init__(self, windowing): + super(_GroupAlsoByWindow, self).__init__( + _GroupAlsoByWindowDoFn(windowing)) + self.windowing = windowing + + def expand(self, pcoll): + self._check_pcollection(pcoll) + return pvalue.PCollection(pcoll.pipeline) + + +class _GroupAlsoByWindowDoFn(DoFn): + # TODO(robertwb): Support combiner lifting. + + def __init__(self, windowing): + super(_GroupAlsoByWindowDoFn, self).__init__() + self.windowing = windowing + + def infer_output_type(self, input_type): + key_type, windowed_value_iter_type = trivial_inference.key_value_types( + input_type) + value_type = windowed_value_iter_type.inner_type.inner_type + return Iterable[KV[key_type, Iterable[value_type]]] + + def start_bundle(self): + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam.transforms.trigger import InMemoryUnmergedState + from apache_beam.transforms.trigger import create_trigger_driver + # pylint: enable=wrong-import-order, wrong-import-position + self.driver = create_trigger_driver(self.windowing, True) + self.state_type = InMemoryUnmergedState + + def process(self, element): + k, vs = element + state = self.state_type() + # TODO(robertwb): Conditionally process in smaller chunks. + for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): + yield wvalue.with_value((k, wvalue.value)) + while state.timers: + fired = state.get_and_clear_timers() + for timer_window, (name, time_domain, fire_time) in fired: + for wvalue in self.driver.process_timer( + timer_window, name, time_domain, fire_time, state): + yield wvalue.with_value((k, wvalue.value)) + + class Partition(PTransformWithSideInputs): """Split a PCollection into several partitions. 
From 329bf1e775c29b84a498fc106342fddd6e11f0b6 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Wed, 14 Jun 2017 16:35:45 -0700 Subject: [PATCH 040/200] [BEAM-1585] Add beam plugins as pipeline options --- sdks/python/apache_beam/io/filesystem.py | 14 ++----- .../apache_beam/options/pipeline_options.py | 8 ++++ .../runners/dataflow/dataflow_runner.py | 10 +++++ sdks/python/apache_beam/utils/plugin.py | 42 +++++++++++++++++++ 4 files changed, 63 insertions(+), 11 deletions(-) create mode 100644 sdks/python/apache_beam/utils/plugin.py diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index db6a1d0f6e6c2..f5530262b4ca7 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -26,6 +26,8 @@ import logging import time +from apache_beam.utils.plugin import BeamPlugin + logger = logging.getLogger(__name__) DEFAULT_READ_BUFFER_SIZE = 16 * 1024 * 1024 @@ -409,7 +411,7 @@ def __init__(self, msg, exception_details=None): self.exception_details = exception_details -class FileSystem(object): +class FileSystem(BeamPlugin): """A class that defines the functions that can be performed on a filesystem. All methods are abstract and they are for file system providers to @@ -428,16 +430,6 @@ def _get_compression_type(path, compression_type): 'was %s' % type(compression_type)) return compression_type - @classmethod - def get_all_subclasses(cls): - """Get all the subclasses of the FileSystem class - """ - all_subclasses = [] - for subclass in cls.__subclasses__(): - all_subclasses.append(subclass) - all_subclasses.extend(subclass.get_all_subclasses()) - return all_subclasses - @classmethod def scheme(cls): """URI scheme for the FileSystem diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 8598e057e83c1..283b340ecfc37 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -544,6 +544,14 @@ def _add_argparse_args(cls, parser): 'During job submission a source distribution will be built and the ' 'worker will install the resulting package before running any custom ' 'code.')) + parser.add_argument( + '--beam_plugins', + default=None, + help= + ('Bootstrap the python process before executing any code by importing ' + 'all the plugins used in the pipeline. Please pass a comma separated' + 'list of import paths to be included. 
This is currently an ' + 'experimental flag and provides no stability.')) parser.add_argument( '--save_main_session', default=False, diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index d6944b2802647..cc9274ec40c78 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -46,6 +46,8 @@ from apache_beam.transforms.display import DisplayData from apache_beam.typehints import typehints from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.utils.plugin import BeamPlugin __all__ = ['DataflowRunner'] @@ -226,6 +228,14 @@ def run(self, pipeline): raise ImportError( 'Google Cloud Dataflow runner not available, ' 'please install apache_beam[gcp]') + + # Add setup_options for all the BeamPlugin imports + setup_options = pipeline._options.view_as(SetupOptions) + plugins = BeamPlugin.get_all_plugin_paths() + if setup_options.beam_plugins is not None: + plugins = list(set(plugins + setup_options.beam_plugins.split(','))) + setup_options.beam_plugins = plugins + self.job = apiclient.Job(pipeline._options) # Dataflow runner requires a KV type for GBK inputs, hence we enforce that diff --git a/sdks/python/apache_beam/utils/plugin.py b/sdks/python/apache_beam/utils/plugin.py new file mode 100644 index 0000000000000..563b93c54c7d0 --- /dev/null +++ b/sdks/python/apache_beam/utils/plugin.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A BeamPlugin base class. + +For experimental usage only; no backwards-compatibility guarantees. +""" + + +class BeamPlugin(object): + """Plugin base class to be extended by dependent users such as FileSystem. + Any instantiated subclass will be imported at worker startup time.""" + + @classmethod + def get_all_subclasses(cls): + """Get all the subclasses of the BeamPlugin class.""" + all_subclasses = [] + for subclass in cls.__subclasses__(): + all_subclasses.append(subclass) + all_subclasses.extend(subclass.get_all_subclasses()) + return all_subclasses + + @classmethod + def get_all_plugin_paths(cls): + """Get full import paths of the BeamPlugin subclass.""" + def fullname(o): + return o.__module__ + "." 
+ o.__name__ + return [fullname(o) for o in cls.get_all_subclasses()] From be09a162e32d158f5ae043e064223bb4f3742648 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Wed, 14 Jun 2017 16:14:50 -0700 Subject: [PATCH 041/200] Migrate DirectRunner evaluators to use Beam state API --- .../runners/dataflow/native_io/iobase_test.py | 39 +++++++++- .../runners/direct/evaluation_context.py | 56 ++++++++++---- .../runners/direct/transform_evaluator.py | 74 ++++++++++--------- .../runners/direct/transform_result.py | 3 +- 4 files changed, 122 insertions(+), 50 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py index 7610baff6b479..3d8c24f5651cd 100644 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py +++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py @@ -20,7 +20,9 @@ import unittest -from apache_beam import error, pvalue +from apache_beam import Create +from apache_beam import error +from apache_beam import pvalue from apache_beam.runners.dataflow.native_io.iobase import ( _dict_printable_fields, _NativeWrite, @@ -28,10 +30,12 @@ DynamicSplitRequest, DynamicSplitResultWithPosition, NativeSink, + NativeSinkWriter, NativeSource, ReaderPosition, ReaderProgress ) +from apache_beam.testing.test_pipeline import TestPipeline class TestHelperFunctions(unittest.TestCase): @@ -154,6 +158,39 @@ def __init__(self, validate=False, dataset=None, project=None, fake_sink = FakeSink() self.assertEqual(fake_sink.__repr__(), "") + def test_on_direct_runner(self): + class FakeSink(NativeSink): + """A fake sink outputing a number of elements.""" + + def __init__(self): + self.written_values = [] + self.writer_instance = FakeSinkWriter(self.written_values) + + def writer(self): + return self.writer_instance + + class FakeSinkWriter(NativeSinkWriter): + """A fake sink writer for testing.""" + + def __init__(self, written_values): + self.written_values = written_values + + def __enter__(self): + return self + + def __exit__(self, *unused_args): + pass + + def Write(self, value): + self.written_values.append(value) + + p = TestPipeline() + sink = FakeSink() + p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned + p.run() + + self.assertEqual(['a', 'b', 'c'], sink.written_values) + class Test_NativeWrite(unittest.TestCase): diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 68d99d373a7e1..8fa8e06922d03 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -27,22 +27,22 @@ from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.runners.direct.executor import TransformExecutor from apache_beam.runners.direct.direct_metrics import DirectMetrics +from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.utils import counters class _ExecutionContext(object): - def __init__(self, watermarks, existing_state): - self._watermarks = watermarks - self._existing_state = existing_state + def __init__(self, watermarks, keyed_states): + self.watermarks = watermarks + self.keyed_states = keyed_states - @property - def watermarks(self): - return self._watermarks + self._step_context = None - @property - def existing_state(self): - return self._existing_state + def get_step_context(self): + if not self._step_context: + 
self._step_context = DirectStepContext(self.keyed_states) + return self._step_context class _SideInputView(object): @@ -145,9 +145,8 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._pcollection_to_views = collections.defaultdict(list) for view in views: self._pcollection_to_views[view.pvalue].append(view) - - # AppliedPTransform -> Evaluator specific state objects - self._application_state_interals = {} + self._transform_keyed_states = self._initialize_keyed_states( + root_transforms, value_to_consumers) self._watermark_manager = WatermarkManager( Clock(), root_transforms, value_to_consumers) self._side_inputs_container = _SideInputsContainer(views) @@ -158,6 +157,15 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._lock = threading.Lock() + def _initialize_keyed_states(self, root_transforms, value_to_consumers): + transform_keyed_states = {} + for transform in root_transforms: + transform_keyed_states[transform] = {} + for consumers in value_to_consumers.values(): + for consumer in consumers: + transform_keyed_states[consumer] = {} + return transform_keyed_states + def use_pvalue_cache(self, cache): assert not self._cache self._cache = cache @@ -231,7 +239,6 @@ def handle_result( counter.name, counter.combine_fn) merged_counter.accumulator.merge([counter.accumulator]) - self._application_state_interals[result.transform] = result.state return committed_bundles def get_aggregator_values(self, aggregator_or_name): @@ -256,7 +263,7 @@ def _commit_bundles(self, uncommitted_bundles): def get_execution_context(self, applied_ptransform): return _ExecutionContext( self._watermark_manager.get_watermarks(applied_ptransform), - self._application_state_interals.get(applied_ptransform)) + self._transform_keyed_states[applied_ptransform]) def create_bundle(self, output_pcollection): """Create an uncommitted bundle for the specified PCollection.""" @@ -296,3 +303,24 @@ def get_value_or_schedule_after_output(self, side_input, task): assert isinstance(task, TransformExecutor) return self._side_inputs_container.get_value_or_schedule_after_output( side_input, task) + + +class DirectUnmergedState(InMemoryUnmergedState): + """UnmergedState implementation for the DirectRunner.""" + + def __init__(self): + super(DirectUnmergedState, self).__init__(defensive_copy=False) + + +class DirectStepContext(object): + """Context for the currently-executing step.""" + + def __init__(self, keyed_existing_state): + self.keyed_existing_state = keyed_existing_state + + def get_keyed_state(self, key): + # TODO(ccy): consider implementing transactional copy on write semantics + # for state so that work items can be safely retried. 
+ if not self.keyed_existing_state.get(key): + self.keyed_existing_state[key] = DirectUnmergedState() + return self.keyed_existing_state[key] diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index b1cb626ca0cb6..f5b5db5c0a773 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -33,6 +33,8 @@ from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue +from apache_beam.transforms.trigger import _CombiningValueStateTag +from apache_beam.transforms.trigger import _ListStateTag from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn @@ -207,7 +209,7 @@ def _read_values_to_bundles(reader): bundles = _read_values_to_bundles(reader) return TransformResult( - self._applied_ptransform, bundles, None, None, None, None) + self._applied_ptransform, bundles, None, None, None) class _FlattenEvaluator(_TransformEvaluator): @@ -231,7 +233,7 @@ def process_element(self, element): def finish_bundle(self): bundles = [self.bundle] return TransformResult( - self._applied_ptransform, bundles, None, None, None, None) + self._applied_ptransform, bundles, None, None, None) class _TaggedReceivers(dict): @@ -320,7 +322,7 @@ def finish_bundle(self): bundles = self._tagged_receivers.values() result_counters = self._counter_factory.get_counters() return TransformResult( - self._applied_ptransform, bundles, None, None, result_counters, None, + self._applied_ptransform, bundles, None, result_counters, None, self._tagged_receivers.undeclared_in_memory_tag_values) @@ -328,13 +330,8 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator): """TransformEvaluator for _GroupByKeyOnly transform.""" MAX_ELEMENT_PER_BUNDLE = None - - class _GroupByKeyOnlyEvaluatorState(object): - - def __init__(self): - # output: {} key -> [values] - self.output = collections.defaultdict(list) - self.completed = False + ELEMENTS_TAG = _ListStateTag('elements') + COMPLETION_TAG = _CombiningValueStateTag('completed', any) def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): @@ -349,9 +346,8 @@ def _is_final_bundle(self): == WatermarkManager.WATERMARK_POS_INF) def start_bundle(self): - self.state = (self._execution_context.existing_state - if self._execution_context.existing_state - else _GroupByKeyOnlyEvaluator._GroupByKeyOnlyEvaluatorState()) + self.step_context = self._execution_context.get_step_context() + self.global_state = self.step_context.get_keyed_state(None) assert len(self._outputs) == 1 self.output_pcollection = list(self._outputs)[0] @@ -362,12 +358,15 @@ def start_bundle(self): self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) def process_element(self, element): - assert not self.state.completed + assert not self.global_state.get_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG) if (isinstance(element, WindowedValue) and isinstance(element.value, collections.Iterable) and len(element.value) == 2): k, v = element.value - self.state.output[self.key_coder.encode(k)].append(v) + encoded_k = self.key_coder.encode(k) + state = self.step_context.get_keyed_state(encoded_k) + state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v) 
else: raise TypeCheckError('Input to _GroupByKeyOnly must be a PCollection of ' 'windowed key-value pairs. Instead received: %r.' @@ -375,15 +374,23 @@ def process_element(self, element): def finish_bundle(self): if self._is_final_bundle: - if self.state.completed: + if self.global_state.get_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG): # Ignore empty bundles after emitting output. (This may happen because # empty bundles do not affect input watermarks.) bundles = [] else: - gbk_result = ( - map(GlobalWindows.windowed_value, ( - (self.key_coder.decode(k), v) - for k, v in self.state.output.iteritems()))) + gbk_result = [] + # TODO(ccy): perhaps we can clean this up to not use this + # internal attribute of the DirectStepContext. + for encoded_k in self.step_context.keyed_existing_state: + # Ignore global state. + if encoded_k is None: + continue + k = self.key_coder.decode(encoded_k) + state = self.step_context.get_keyed_state(encoded_k) + vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG) + gbk_result.append(GlobalWindows.windowed_value((k, vs))) def len_element_fn(element): _, v = element.value @@ -393,21 +400,22 @@ def len_element_fn(element): self.output_pcollection, gbk_result, _GroupByKeyOnlyEvaluator.MAX_ELEMENT_PER_BUNDLE, len_element_fn) - self.state.completed = True - state = self.state + self.global_state.add_state( + None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG, True) hold = WatermarkManager.WATERMARK_POS_INF else: bundles = [] - state = self.state hold = WatermarkManager.WATERMARK_NEG_INF return TransformResult( - self._applied_ptransform, bundles, state, None, None, hold) + self._applied_ptransform, bundles, None, None, hold) class _NativeWriteEvaluator(_TransformEvaluator): """TransformEvaluator for _NativeWrite transform.""" + ELEMENTS_TAG = _ListStateTag('elements') + def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): assert not side_inputs @@ -429,12 +437,12 @@ def _has_already_produced_output(self): == WatermarkManager.WATERMARK_POS_INF) def start_bundle(self): - # state: [values] - self.state = (self._execution_context.existing_state - if self._execution_context.existing_state else []) + self.step_context = self._execution_context.get_step_context() + self.global_state = self.step_context.get_keyed_state(None) def process_element(self, element): - self.state.append(element) + self.global_state.add_state( + None, _NativeWriteEvaluator.ELEMENTS_TAG, element) def finish_bundle(self): # finish_bundle will append incoming bundles in memory until all the bundles @@ -444,19 +452,19 @@ def finish_bundle(self): # ignored and would not generate additional output files. # TODO(altay): Do not wait until the last bundle to write in a single shard. if self._is_final_bundle: + elements = self.global_state.get_state( + None, _NativeWriteEvaluator.ELEMENTS_TAG) if self._has_already_produced_output: # Ignore empty bundles that arrive after the output is produced. 
- assert self.state == [] + assert elements == [] else: self._sink.pipeline_options = self._evaluation_context.pipeline_options with self._sink.writer() as writer: - for v in self.state: + for v in elements: writer.Write(v.value) - state = None hold = WatermarkManager.WATERMARK_POS_INF else: - state = self.state hold = WatermarkManager.WATERMARK_NEG_INF return TransformResult( - self._applied_ptransform, [], state, None, None, hold) + self._applied_ptransform, [], None, None, hold) diff --git a/sdks/python/apache_beam/runners/direct/transform_result.py b/sdks/python/apache_beam/runners/direct/transform_result.py index febdd202aa0a9..51593e3a434ba 100644 --- a/sdks/python/apache_beam/runners/direct/transform_result.py +++ b/sdks/python/apache_beam/runners/direct/transform_result.py @@ -25,12 +25,11 @@ class TransformResult(object): The result of evaluating an AppliedPTransform with a TransformEvaluator.""" - def __init__(self, applied_ptransform, uncommitted_output_bundles, state, + def __init__(self, applied_ptransform, uncommitted_output_bundles, timer_update, counters, watermark_hold, undeclared_tag_values=None): self.transform = applied_ptransform self.uncommitted_output_bundles = uncommitted_output_bundles - self.state = state # TODO: timer update is currently unused. self.timer_update = timer_update self.counters = counters From 628366cd2e33ebef4ddede67d3ce84663d725f84 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Wed, 14 Jun 2017 20:13:01 -0700 Subject: [PATCH 042/200] [BEAM-1348] Mark Runner API like types declared within Fn API as deprecated. Add additional documentation. --- .../fn-api/src/main/proto/beam_fn_api.proto | 63 +++++++++++-------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto index 9fe2b2fa21d84..95fe0424f3ae1 100644 --- a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto +++ b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto @@ -67,43 +67,47 @@ message Target { string name = 2; } -// Information defining a PCollection +// (Deprecated) Information defining a PCollection +// +// Migrate to Runner API. message PCollection { // (Required) A reference to a coder. - string coder_reference = 1; + string coder_reference = 1 [deprecated = true]; // TODO: Windowing strategy, ... } -// A primitive transform within Apache Beam. +// (Deprecated) A primitive transform within Apache Beam. +// +// Migrate to Runner API. message PrimitiveTransform { // (Required) A pipeline level unique id which can be used as a reference to // refer to this. - string id = 1; + string id = 1 [deprecated = true]; // (Required) A function spec that is used by this primitive // transform to process data. - FunctionSpec function_spec = 2; + FunctionSpec function_spec = 2 [deprecated = true]; // A map of distinct input names to target definitions. // For example, in CoGbk this represents the tag name associated with each // distinct input name and a list of primitive transforms that are associated // with the specified input. - map inputs = 3; + map inputs = 3 [deprecated = true]; // A map from local output name to PCollection definitions. For example, in // DoFn this represents the tag name associated with each distinct output. 
- map outputs = 4; + map outputs = 4 [deprecated = true]; // TODO: Should we model side inputs as a special type of input for a // primitive transform or should it be modeled as the relationship that // the predecessor input will be a view primitive transform. // A map of from side input names to side inputs. - map side_inputs = 5; + map side_inputs = 5 [deprecated = true]; // The user name of this step. // TODO: This should really be in display data and not at this level - string step_name = 6; + string step_name = 6 [deprecated = true]; } /* @@ -112,13 +116,14 @@ message PrimitiveTransform { * This is still unstable mainly due to how we model the side input. */ -// Defines the common elements of user-definable functions, to allow the SDK to -// express the information the runner needs to execute work. -// Stable +// (Deprecated) Defines the common elements of user-definable functions, +// to allow the SDK to express the information the runner needs to execute work. +// +// Migrate to Runner API. message FunctionSpec { // (Required) A pipeline level unique id which can be used as a reference to // refer to this. - string id = 1; + string id = 1 [deprecated = true]; // (Required) A globally unique name representing this user definable // function. @@ -128,30 +133,31 @@ message FunctionSpec { // // For example: // urn:org.apache.beam:coder:kv:1.0 - string urn = 2; + string urn = 2 [deprecated = true]; // (Required) Reference to specification of execution environment required to // invoke this function. - string environment_reference = 3; + string environment_reference = 3 [deprecated = true]; // Data used to parameterize this function. Depending on the urn, this may be // optional or required. - google.protobuf.Any data = 4; + google.protobuf.Any data = 4 [deprecated = true]; } +// (Deprecated) Migrate to Runner API. message SideInput { // TODO: Coder? // For RunnerAPI. - Target input = 1; + Target input = 1 [deprecated = true]; // For FnAPI. - FunctionSpec view_fn = 2; + FunctionSpec view_fn = 2 [deprecated = true]; } -// Defines how to encode values into byte streams and decode values from byte -// streams. A coder can be parameterized by additional properties which may or -// may not be language agnostic. +// (Deprecated) Defines how to encode values into byte streams and decode +// values from byte streams. A coder can be parameterized by additional +// properties which may or may not be language agnostic. // // Coders using the urn:org.apache.beam:coder namespace must have their // encodings registered such that another may implement the encoding within @@ -160,14 +166,15 @@ message SideInput { // For example: // urn:org.apache.beam:coder:kv:1.0 // urn:org.apache.beam:coder:iterable:1.0 -// Stable +// +// Migrate to Runner API. message Coder { // TODO: This looks weird when compared to the other function specs // which use URN to differentiate themselves. Should "Coder" be embedded // inside the FunctionSpec data block. // The data associated with this coder used to reconstruct it. - FunctionSpec function_spec = 1; + FunctionSpec function_spec = 1 [deprecated = true]; // A list of component coder references. // @@ -180,7 +187,7 @@ message Coder { // // TODO: Perhaps this is redundant with the data of the FunctionSpec // for known coders? - repeated string component_coder_reference = 2; + repeated string component_coder_reference = 2 [deprecated = true]; } // A descriptor for connecting to a remote port using the Beam Fn Data API. 
@@ -273,10 +280,14 @@ message ProcessBundleDescriptor { // (Deprecated) A list of primitive transforms that should // be used to construct the bundle processing graph. - repeated PrimitiveTransform primitive_transform = 2; + // + // Migrate to Runner API definitions found within transforms field. + repeated PrimitiveTransform primitive_transform = 2 [deprecated = true]; // (Deprecated) The set of all coders referenced in this bundle. - repeated Coder coders = 4; + // + // Migrate to Runner API defintions found within codersyyy field. + repeated Coder coders = 4 [deprecated = true]; // (Required) A map from pipeline-scoped id to PTransform. map transforms = 5; From 81a72192dc4e792966de31c8eadda6a6c839a62c Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 12 Jun 2017 16:31:32 -0700 Subject: [PATCH 043/200] Fix getAdditionalInputs, etc, for DirectRunner stateful ParDo override --- .../direct/ParDoMultiOverrideFactory.java | 90 ++++++++++++++----- .../direct/StatefulParDoEvaluatorFactory.java | 11 ++- .../StatefulParDoEvaluatorFactoryTest.java | 65 +++++++------- 3 files changed, 102 insertions(+), 64 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 858ea3400a796..b20113edf588d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,6 +19,8 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.common.collect.ImmutableMap; +import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItemCoder; @@ -27,7 +29,6 @@ import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.core.construction.SplittableParDo; -import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -48,6 +49,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -82,12 +84,14 @@ private PTransform, PCollectionTuple> getReplaceme return new SplittableParDo(transform); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { + // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed - MultiOutput, OutputT> keyedTransform = - (MultiOutput, OutputT>) transform; - - return new GbkThenStatefulParDo(keyedTransform); + return new GbkThenStatefulParDo( + fn, + transform.getMainOutputTag(), + transform.getAdditionalOutputTags(), + transform.getSideInputs()); } else { return transform; } @@ -101,10 +105,29 @@ public Map mapOutputs( static class GbkThenStatefulParDo extends PTransform>, PCollectionTuple> { - private final MultiOutput, OutputT> underlyingParDo; + private final transient DoFn, OutputT> doFn; + private final TupleTagList additionalOutputTags; + private final TupleTag 
mainOutputTag; + private final List> sideInputs; + + public GbkThenStatefulParDo( + DoFn, OutputT> doFn, + TupleTag mainOutputTag, + TupleTagList additionalOutputTags, + List> sideInputs) { + this.doFn = doFn; + this.additionalOutputTags = additionalOutputTags; + this.mainOutputTag = mainOutputTag; + this.sideInputs = sideInputs; + } - public GbkThenStatefulParDo(MultiOutput, OutputT> underlyingParDo) { - this.underlyingParDo = underlyingParDo; + @Override + public Map, PValue> getAdditionalInputs() { + ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); + for (PCollectionView sideInput : sideInputs) { + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); + } + return additionalInputs.build(); } @Override @@ -160,7 +183,9 @@ public PCollectionTuple expand(PCollection> input) { adjustedInput // Explode the resulting iterable into elements that are exactly the ones from // the input - .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input)); + .apply( + "Stateful ParDo", + new StatefulParDo<>(doFn, mainOutputTag, additionalOutputTags, sideInputs)); return outputs; } @@ -172,25 +197,45 @@ public PCollectionTuple expand(PCollection> input) { static class StatefulParDo extends PTransformTranslation.RawPTransform< PCollection>>, PCollectionTuple> { - private final transient MultiOutput, OutputT> underlyingParDo; - private final transient PCollection> originalInput; + private final transient DoFn, OutputT> doFn; + private final TupleTagList additionalOutputTags; + private final TupleTag mainOutputTag; + private final List> sideInputs; public StatefulParDo( - MultiOutput, OutputT> underlyingParDo, - PCollection> originalInput) { - this.underlyingParDo = underlyingParDo; - this.originalInput = originalInput; + DoFn, OutputT> doFn, + TupleTag mainOutputTag, + TupleTagList additionalOutputTags, + List> sideInputs) { + this.doFn = doFn; + this.mainOutputTag = mainOutputTag; + this.additionalOutputTags = additionalOutputTags; + this.sideInputs = sideInputs; + } + + public DoFn, OutputT> getDoFn() { + return doFn; + } + + public TupleTag getMainOutputTag() { + return mainOutputTag; + } + + public List> getSideInputs() { + return sideInputs; } - public MultiOutput, OutputT> getUnderlyingParDo() { - return underlyingParDo; + public TupleTagList getAdditionalOutputTags() { + return additionalOutputTags; } @Override - public Coder getDefaultOutputCoder( - PCollection>> input, PCollection output) - throws CannotProvideCoderException { - return underlyingParDo.getDefaultOutputCoder(originalInput, output); + public Map, PValue> getAdditionalInputs() { + ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); + for (PCollectionView sideInput : sideInputs) { + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); + } + return additionalInputs.build(); } @Override @@ -199,8 +244,7 @@ public PCollectionTuple expand(PCollection>> createEvaluator( throws Exception { final DoFn, OutputT> doFn = - application.getTransform().getUnderlyingParDo().getFn(); + application.getTransform().getDoFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); // If the DoFn is stateful, schedule state clearing. 
@@ -120,9 +120,9 @@ private TransformEvaluator>> createEvaluator( (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, - application.getTransform().getUnderlyingParDo().getSideInputs(), - application.getTransform().getUnderlyingParDo().getMainOutputTag(), - application.getTransform().getUnderlyingParDo().getAdditionalOutputTags().getAll()); + application.getTransform().getSideInputs(), + application.getTransform().getMainOutputTag(), + application.getTransform().getAdditionalOutputTags().getAll()); return new StatefulParDoEvaluator<>(delegateEvaluator); } @@ -152,12 +152,11 @@ public Runnable load( transformOutputWindow .getTransform() .getTransform() - .getUnderlyingParDo() .getMainOutputTag()); WindowingStrategy windowingStrategy = pc.getWindowingStrategy(); BoundedWindow window = transformOutputWindow.getWindow(); final DoFn doFn = - transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getFn(); + transformOutputWindow.getTransform().getTransform().getDoFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); final DirectStepContext stepContext = diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index 9366b7c9ff8c2..fe0b743c460b2 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -41,6 +41,7 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -52,7 +53,6 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; @@ -128,16 +128,17 @@ public void windowCleanupScheduled() throws Exception { input .apply( new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( - ParDo.of( - new DoFn, Integer>() { - @StateId(stateId) - private final StateSpec> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - }) - .withOutputTags(mainOutput, TupleTagList.empty()))) + new DoFn, Integer>() { + @StateId(stateId) + private final StateSpec> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }, + mainOutput, + TupleTagList.empty(), + Collections.>emptyList())) .get(mainOutput) .setCoder(VarIntCoder.of()); @@ -153,8 +154,7 @@ public void process(ProcessContext c) {} when(mockEvaluationContext.getExecutionContext( eq(producingTransform), Mockito.any())) .thenReturn(mockExecutionContext); - when(mockExecutionContext.getStepContext(anyString())) - .thenReturn(mockStepContext); + when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext); IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new 
Instant(9)); IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(19)); @@ -241,18 +241,17 @@ public void testUnprocessedElements() throws Exception { mainInput .apply( new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( - ParDo - .of( - new DoFn, Integer>() { - @StateId(stateId) - private final StateSpec> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - }) - .withSideInputs(sideInput) - .withOutputTags(mainOutput, TupleTagList.empty()))) + new DoFn, Integer>() { + @StateId(stateId) + private final StateSpec> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }, + mainOutput, + TupleTagList.empty(), + Collections.>singletonList(sideInput))) .get(mainOutput) .setCoder(VarIntCoder.of()); @@ -269,8 +268,7 @@ public void process(ProcessContext c) {} when(mockEvaluationContext.getExecutionContext( eq(producingTransform), Mockito.any())) .thenReturn(mockExecutionContext); - when(mockExecutionContext.getStepContext(anyString())) - .thenReturn(mockStepContext); + when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext); when(mockEvaluationContext.createBundle(Matchers.>any())) .thenReturn(mockUncommittedBundle); when(mockStepContext.getTimerUpdate()).thenReturn(TimerUpdate.empty()); @@ -287,11 +285,8 @@ public void process(ProcessContext c) {} // global window state merely by having the evaluator created. The cleanup logic does not // depend on the window. String key = "hello"; - WindowedValue> firstKv = WindowedValue.of( - KV.of(key, 1), - new Instant(3), - firstWindow, - PaneInfo.NO_FIRING); + WindowedValue> firstKv = + WindowedValue.of(KV.of(key, 1), new Instant(3), firstWindow, PaneInfo.NO_FIRING); WindowedValue>> gbkOutputElement = firstKv.withValue( @@ -306,7 +301,8 @@ public void process(ProcessContext c) {} BUNDLE_FACTORY .createBundle( (PCollection>>) - Iterables.getOnlyElement(producingTransform.getInputs().values())) + Iterables.getOnlyElement( + TransformInputs.nonAdditionalInputs(producingTransform))) .add(gbkOutputElement) .commit(Instant.now()); TransformEvaluator>> evaluator = @@ -316,8 +312,7 @@ public void process(ProcessContext c) {} // This should push back every element as a KV> // in the appropriate window. 
Since the keys are equal they are single-threaded - TransformResult>> result = - evaluator.finishBundle(); + TransformResult>> result = evaluator.finishBundle(); List pushedBackInts = new ArrayList<>(); From 581ee1520e497fca95e8c4aa75f90050952523d0 Mon Sep 17 00:00:00 2001 From: JingsongLi Date: Tue, 13 Jun 2017 11:26:38 +0800 Subject: [PATCH 044/200] [BEAM-2423] Port state internals tests to the new base class StateInternalsTest --- runners/apex/pom.xml | 7 + .../utils/ApexStateInternalsTest.java | 411 ++++-------------- .../core/InMemoryStateInternalsTest.java | 46 +- .../beam/runners/core/StateInternalsTest.java | 14 +- .../FlinkBroadcastStateInternalsTest.java | 242 +++-------- .../FlinkKeyGroupStateInternalsTest.java | 359 +++++++-------- .../FlinkSplitStateInternalsTest.java | 132 +++--- runners/spark/pom.xml | 7 + .../stateful/SparkStateInternalsTest.java | 66 +++ 9 files changed, 521 insertions(+), 763 deletions(-) create mode 100644 runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml index 4a36bec8ab2c5..d3d4318d2dcf4 100644 --- a/runners/apex/pom.xml +++ b/runners/apex/pom.xml @@ -184,6 +184,13 @@ test-jar test + + + org.apache.beam + beam-runners-core-java + test-jar + test + diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java index a7e64af4dd5b6..87aa8c28041ee 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java @@ -18,350 +18,109 @@ package org.apache.beam.runners.apex.translation.utils; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import com.datatorrent.lib.util.KryoCloneUtils; -import java.util.Arrays; -import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; -import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateInternalsFactory; -import org.apache.beam.runners.core.StateMerging; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.GroupingState; -import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.state.WatermarkHoldState; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.hamcrest.Matchers; -import org.joda.time.Instant; -import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Suite; /** * Tests 
for {@link ApexStateInternals}. This is based on the tests for - * {@code InMemoryStateInternals}. + * {@code StateInternalsTest}. */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + ApexStateInternalsTest.StandardStateInternalsTests.class, + ApexStateInternalsTest.OtherTests.class +}) public class ApexStateInternalsTest { - private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - private static final StateTag> STRING_VALUE_ADDR = - StateTags.value("stringValue", StringUtf8Coder.of()); - private static final StateTag> - SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( - "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - private static final StateTag - WATERMARK_EARLIEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); - private static final StateTag WATERMARK_LATEST_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); - private static final StateTag WATERMARK_EOW_ADDR = - StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); - - private ApexStateInternals underTest; - - @Before - public void initStateInternals() { - underTest = new ApexStateInternals.ApexStateBackend() + private static StateInternals newStateInternals() { + return new ApexStateInternals.ApexStateBackend() .newStateInternalsFactory(StringUtf8Coder.of()) - .stateInternalsForKey((String) null); - } - - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); - - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello")); - - value.add("world"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); - - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); - - } - - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeBagIntoSource() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); - - // Reading the merged bag gets both the contents - assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag2.read(), Matchers.emptyIterable()); - } - - @Test - public void testMergeBagIntoNewNamespace() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - 
BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); - - // Reading the merged bag gets both the contents - assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag1.read(), Matchers.emptyIterable()); - assertThat(bag2.read(), Matchers.emptyIterable()); + .stateInternalsForKey("dummyKey"); + } + + /** + * A standard StateInternals test. Ignore set and map tests. + */ + @RunWith(JUnit4.class) + public static class StandardStateInternalsTests extends StateInternalsTest { + @Override + protected StateInternals createStateInternals() { + return newStateInternals(); + } + + @Override + @Ignore + public void testSet() {} + + @Override + @Ignore + public void testSetIsEmpty() {} + + @Override + @Ignore + public void testMergeSetIntoSource() {} + + @Override + @Ignore + public void testMergeSetIntoNewNamespace() {} + + @Override + @Ignore + public void testMap() {} + + @Override + @Ignore + public void testSetReadable() {} + + @Override + @Ignore + public void testMapReadable() {} + } + + /** + * A specific test of ApexStateInternalsTest. + */ + @RunWith(JUnit4.class) + public static class OtherTests { + + private static final StateNamespace NAMESPACE = new StateNamespaceForTest("ns"); + private static final StateTag> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + + @Test + public void testSerialization() throws Exception { + ApexStateInternals.ApexStateInternalsFactory sif = + new ApexStateInternals.ApexStateBackend(). + newStateInternalsFactory(StringUtf8Coder.of()); + ApexStateInternals keyAndState = sif.stateInternalsForKey("dummy"); + + ValueState value = keyAndState.state(NAMESPACE, STRING_VALUE_ADDR); + assertEquals(keyAndState.state(NAMESPACE, STRING_VALUE_ADDR), value); + value.write("hello"); + + ApexStateInternals.ApexStateInternalsFactory cloned; + assertNotNull("Serialization", cloned = KryoCloneUtils.cloneObject(sif)); + ApexStateInternals clonedKeyAndState = cloned.stateInternalsForKey("dummy"); + + ValueState clonedValue = clonedKeyAndState.state(NAMESPACE, STRING_VALUE_ADDR); + assertThat(clonedValue.read(), Matchers.equalTo("hello")); + assertEquals(clonedKeyAndState.state(NAMESPACE, STRING_VALUE_ADDR), value); + } } - - @Test - public void testCombiningValue() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); - - assertThat(value.read(), Matchers.equalTo(0)); - value.add(2); - assertThat(value.read(), Matchers.equalTo(2)); - - value.add(3); - assertThat(value.read(), Matchers.equalTo(5)); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(0)); - assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); - } - - @Test - public void testCombiningIsEmpty() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(5); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeCombiningValueIntoSource() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - assertThat(value1.read(), Matchers.equalTo(11)); - assertThat(value2.read(), Matchers.equalTo(10)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); - - assertThat(value1.read(), Matchers.equalTo(21)); - assertThat(value2.read(), Matchers.equalTo(0)); - } - - @Test - public void testMergeCombiningValueIntoNewNamespace() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - CombiningState value3 = - underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); - - // Merging clears the old values and updates the result value. - assertThat(value1.read(), Matchers.equalTo(0)); - assertThat(value2.read(), Matchers.equalTo(0)); - assertThat(value3.read(), Matchers.equalTo(21)); - } - - @Test - public void testWatermarkEarliestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); - - value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); - - value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(1000))); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value); - } - - @Test - public void testWatermarkLatestState() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - - // State instances are cached, but depend on the namespace. 
- assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); - - value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); - - value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value); - } - - @Test - public void testWatermarkEndOfWindowState() throws Exception { - WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); - - assertThat(value.read(), Matchers.nullValue()); - value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); - assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value); - } - - @Test - public void testWatermarkStateIsEmpty() throws Exception { - WatermarkHoldState value = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(new Instant(1000)); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeEarliestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the merged value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); - - assertThat(value1.read(), Matchers.equalTo(new Instant(2000))); - assertThat(value2.read(), Matchers.equalTo(null)); - } - - @Test - public void testMergeLatestWatermarkIntoSource() throws Exception { - WatermarkHoldState value1 = - underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); - WatermarkHoldState value2 = - underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); - WatermarkHoldState value3 = - underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); - - value1.add(new Instant(3000)); - value2.add(new Instant(5000)); - value1.add(new Instant(4000)); - value2.add(new Instant(2000)); - - // Merging clears the old values and updates the result value. - StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); - - // Merging clears the old values and updates the result value. - assertThat(value3.read(), Matchers.equalTo(new Instant(5000))); - assertThat(value1.read(), Matchers.equalTo(null)); - assertThat(value2.read(), Matchers.equalTo(null)); - } - - @Test - public void testSerialization() throws Exception { - ApexStateInternalsFactory sif = new ApexStateBackend(). 
- newStateInternalsFactory(StringUtf8Coder.of()); - ApexStateInternals keyAndState = sif.stateInternalsForKey("dummy"); - - ValueState value = keyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR); - assertEquals(keyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - value.write("hello"); - - ApexStateInternalsFactory cloned; - assertNotNull("Serialization", cloned = KryoCloneUtils.cloneObject(sif)); - ApexStateInternals clonedKeyAndState = cloned.stateInternalsForKey("dummy"); - - ValueState clonedValue = clonedKeyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR); - assertThat(clonedValue.read(), Matchers.equalTo("hello")); - assertEquals(clonedKeyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - } - } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java index 335c2f853c97a..1c6cd3003b2d8 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java @@ -19,7 +19,17 @@ import static org.junit.Assert.assertThat; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.hamcrest.Matchers; import org.junit.Test; import org.junit.runner.RunWith; @@ -53,21 +63,41 @@ protected StateInternals createStateInternals() { @RunWith(JUnit4.class) public static class OtherTests { + private static final StateNamespace NAMESPACE = new StateNamespaceForTest("ns"); + + private static final StateTag> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); + private static final StateTag> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + private static final StateTag> STRING_SET_ADDR = + StateTags.set("stringSet", StringUtf8Coder.of()); + private static final StateTag> STRING_MAP_ADDR = + StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of()); + private static final StateTag WATERMARK_EARLIEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); + private static final StateTag WATERMARK_LATEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); + private static final StateTag WATERMARK_EOW_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); + StateInternals underTest = new InMemoryStateInternals<>("dummyKey"); @Test public void testSameInstance() { - assertSameInstance(StateInternalsTest.STRING_VALUE_ADDR); - assertSameInstance(StateInternalsTest.SUM_INTEGER_ADDR); - assertSameInstance(StateInternalsTest.STRING_BAG_ADDR); - assertSameInstance(StateInternalsTest.STRING_SET_ADDR); - assertSameInstance(StateInternalsTest.STRING_MAP_ADDR); - assertSameInstance(StateInternalsTest.WATERMARK_EARLIEST_ADDR); + assertSameInstance(STRING_VALUE_ADDR); + 
assertSameInstance(SUM_INTEGER_ADDR); + assertSameInstance(STRING_BAG_ADDR); + assertSameInstance(STRING_SET_ADDR); + assertSameInstance(STRING_MAP_ADDR); + assertSameInstance(WATERMARK_EARLIEST_ADDR); } private void assertSameInstance(StateTag address) { - assertThat(underTest.state(StateInternalsTest.NAMESPACE_1, address), - Matchers.sameInstance(underTest.state(StateInternalsTest.NAMESPACE_1, address))); + assertThat(underTest.state(NAMESPACE, address), + Matchers.sameInstance(underTest.state(NAMESPACE, address))); } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index 6011fb48aed67..ae07fe6b1ced3 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -56,22 +56,22 @@ public abstract class StateInternalsTest { private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); - static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - static final StateTag> STRING_VALUE_ADDR = + private static final StateTag> STRING_VALUE_ADDR = StateTags.value("stringValue", StringUtf8Coder.of()); - static final StateTag> + private static final StateTag> SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); - static final StateTag> STRING_BAG_ADDR = + private static final StateTag> STRING_BAG_ADDR = StateTags.bag("stringBag", StringUtf8Coder.of()); - static final StateTag> STRING_SET_ADDR = + private static final StateTag> STRING_SET_ADDR = StateTags.set("stringSet", StringUtf8Coder.of()); - static final StateTag> STRING_MAP_ADDR = + private static final StateTag> STRING_MAP_ADDR = StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of()); - static final StateTag WATERMARK_EARLIEST_ADDR = + private static final StateTag WATERMARK_EARLIEST_ADDR = StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); private static final StateTag WATERMARK_LATEST_ADDR = StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java index 2b96d91e917a0..3409d276049b3 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java @@ -17,229 +17,87 @@ */ package org.apache.beam.runners.flink.streaming; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertThat; - -import java.util.Arrays; -import org.apache.beam.runners.core.StateMerging; -import org.apache.beam.runners.core.StateNamespace; -import org.apache.beam.runners.core.StateNamespaceForTest; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; 
+import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.GroupingState; -import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.transforms.Sum; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.hamcrest.Matchers; -import org.junit.Before; -import org.junit.Test; +import org.junit.Ignore; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for - * {@code InMemoryStateInternals}. + * {@code StateInternalsTest}. + * + *

    Just test value, bag and combining. */ @RunWith(JUnit4.class) -public class FlinkBroadcastStateInternalsTest { - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - - private static final StateTag> STRING_VALUE_ADDR = - StateTags.value("stringValue", StringUtf8Coder.of()); - private static final StateTag> - SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( - "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - - FlinkBroadcastStateInternals underTest; - - @Before - public void initStateInternals() { +public class FlinkBroadcastStateInternalsTest extends StateInternalsTest { + + @Override + protected StateInternals createStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), ""); - underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend); - + return new FlinkBroadcastStateInternals<>(1, operatorStateBackend); } catch (Exception e) { throw new RuntimeException(e); } } - @Test - public void testValue() throws Exception { - ValueState value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); - - assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - assertNotEquals( - underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), - value); - - assertThat(value.read(), Matchers.nullValue()); - value.write("hello"); - assertThat(value.read(), Matchers.equalTo("hello")); - value.write("world"); - assertThat(value.read(), Matchers.equalTo("world")); - - value.clear(); - assertThat(value.read(), Matchers.nullValue()); - assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); - - } - - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testSet() {} - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + @Override + @Ignore + public void testSetIsEmpty() {} - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + @Override + @Ignore + public void testMergeSetIntoSource() {} - value.add("world"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + @Override + @Ignore + public void testMergeSetIntoNewNamespace() {} - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); + @Override + @Ignore + public void testMap() {} - } - - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testWatermarkEarliestState() {} - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); + @Override + @Ignore + public void testWatermarkLatestState() {} - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } + @Override + @Ignore + public void testWatermarkEndOfWindowState() {} - @Test - public void 
testMergeBagIntoSource() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + @Override + @Ignore + public void testWatermarkStateIsEmpty() {} - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); + @Override + @Ignore + public void testMergeEarliestWatermarkIntoSource() {} - StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); - - // Reading the merged bag gets both the contents - assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag2.read(), Matchers.emptyIterable()); - } + @Override + @Ignore + public void testMergeLatestWatermarkIntoSource() {} - @Test - public void testMergeBagIntoNewNamespace() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); - - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); - - StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); - - // Reading the merged bag gets both the contents - assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag1.read(), Matchers.emptyIterable()); - assertThat(bag2.read(), Matchers.emptyIterable()); - } + @Override + @Ignore + public void testSetReadable() {} - @Test - public void testCombiningValue() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); - - assertThat(value.read(), Matchers.equalTo(0)); - value.add(2); - assertThat(value.read(), Matchers.equalTo(2)); - - value.add(3); - assertThat(value.read(), Matchers.equalTo(5)); - - value.clear(); - assertThat(value.read(), Matchers.equalTo(0)); - assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); - } - - @Test - public void testCombiningIsEmpty() throws Exception { - GroupingState value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add(5); - assertThat(readFuture.read(), Matchers.is(false)); - - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } - - @Test - public void testMergeCombiningValueIntoSource() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - assertThat(value1.read(), Matchers.equalTo(11)); - assertThat(value2.read(), Matchers.equalTo(10)); - - // Merging clears the old values and updates the result value. 
- StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); - - assertThat(value1.read(), Matchers.equalTo(21)); - assertThat(value2.read(), Matchers.equalTo(0)); - } - - @Test - public void testMergeCombiningValueIntoNewNamespace() throws Exception { - CombiningState value1 = - underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); - CombiningState value2 = - underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); - CombiningState value3 = - underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); - - value1.add(5); - value2.add(10); - value1.add(6); - - StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); - - // Merging clears the old values and updates the result value. - assertThat(value1.read(), Matchers.equalTo(0)); - assertThat(value2.read(), Matchers.equalTo(0)); - assertThat(value3.read(), Matchers.equalTo(21)); - } + @Override + @Ignore + public void testMapReadable() {} } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java index 40123737d2ae5..aed14f3de0c31 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.flink.streaming; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import java.io.ByteArrayInputStream; @@ -26,8 +24,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.nio.ByteBuffer; -import java.util.Arrays; -import org.apache.beam.runners.core.StateMerging; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.StateTag; @@ -35,7 +33,6 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.util.CoderUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; @@ -47,215 +44,219 @@ import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.apache.flink.streaming.api.operators.KeyContext; import org.hamcrest.Matchers; -import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.junit.runners.Suite; /** * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for - * {@code InMemoryStateInternals}. + * {@code StateInternalsTest}. 
*/ -@RunWith(JUnit4.class) +@RunWith(Suite.class) +@Suite.SuiteClasses({ + FlinkKeyGroupStateInternalsTest.StandardStateInternalsTests.class, + FlinkKeyGroupStateInternalsTest.OtherTests.class +}) public class FlinkKeyGroupStateInternalsTest { - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - - FlinkKeyGroupStateInternals underTest; - private KeyedStateBackend keyedStateBackend; - - @Before - public void initStateInternals() { - try { - keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend); - } catch (Exception e) { - throw new RuntimeException(e); + /** + * A standard StateInternals test. Just test BagState. + */ + @RunWith(JUnit4.class) + public static class StandardStateInternalsTests extends StateInternalsTest { + @Override + protected StateInternals createStateInternals() { + KeyedStateBackend keyedStateBackend = + getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + return new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend); } - } - private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups, - KeyGroupRange keyGroupRange) { - MemoryStateBackend backend = new MemoryStateBackend(); - try { - AbstractKeyedStateBackend keyedStateBackend = backend.createKeyedStateBackend( - new DummyEnvironment("test", 1, 0), - new JobID(), - "test_op", - new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()), - numberOfKeyGroups, - keyGroupRange, - new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); - keyedStateBackend.setCurrentKey(ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1"))); - return keyedStateBackend; - } catch (Exception e) { - throw new RuntimeException(e); - } - } + @Override + @Ignore + public void testValue() {} - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testSet() {} - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + @Override + @Ignore + public void testSetIsEmpty() {} - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + @Override + @Ignore + public void testMergeSetIntoSource() {} - value.add("world"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + @Override + @Ignore + public void testMergeSetIntoNewNamespace() {} - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); + @Override + @Ignore + public void testMap() {} - } + @Override + @Ignore + public void testCombiningValue() {} - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testCombiningIsEmpty() {} - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); + @Override + @Ignore + public 
void testMergeCombiningValueIntoSource() {} - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } + @Override + @Ignore + public void testMergeCombiningValueIntoNewNamespace() {} - @Test - public void testMergeBagIntoSource() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + @Override + @Ignore + public void testWatermarkEarliestState() {} - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); + @Override + @Ignore + public void testWatermarkLatestState() {} - StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + @Override + @Ignore + public void testWatermarkEndOfWindowState() {} - // Reading the merged bag gets both the contents - assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag2.read(), Matchers.emptyIterable()); - } + @Override + @Ignore + public void testWatermarkStateIsEmpty() {} - @Test - public void testMergeBagIntoNewNamespace() throws Exception { - BagState bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - BagState bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + @Override + @Ignore + public void testMergeEarliestWatermarkIntoSource() {} - bag1.add("Hello"); - bag2.add("World"); - bag1.add("!"); + @Override + @Ignore + public void testMergeLatestWatermarkIntoSource() {} - StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + @Override + @Ignore + public void testSetReadable() {} - // Reading the merged bag gets both the contents - assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); - assertThat(bag1.read(), Matchers.emptyIterable()); - assertThat(bag2.read(), Matchers.emptyIterable()); + @Override + @Ignore + public void testMapReadable() {} } - @Test - public void testKeyGroupAndCheckpoint() throws Exception { - // assign to keyGroup 0 - ByteBuffer key0 = ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111")); - // assign to keyGroup 1 - ByteBuffer key1 = ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222")); - FlinkKeyGroupStateInternals allState; - { - KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - allState = new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - BagState valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR); - keyedStateBackend.setCurrentKey(key0); - valueForNamespace0.add("0"); - valueForNamespace1.add("2"); - keyedStateBackend.setCurrentKey(key1); - valueForNamespace0.add("1"); - valueForNamespace1.add("3"); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); - } - - ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader(); - - // 1. 
scale up - ByteArrayOutputStream out0 = new ByteArrayOutputStream(); - allState.snapshotKeyGroupState(0, new DataOutputStream(out0)); - DataInputStream in0 = new DataInputStream( - new ByteArrayInputStream(out0.toByteArray())); - { - KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0)); - FlinkKeyGroupStateInternals state0 = - new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - state0.restoreKeyGroupState(0, in0, classLoader); - BagState valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2")); - } - - ByteArrayOutputStream out1 = new ByteArrayOutputStream(); - allState.snapshotKeyGroupState(1, new DataOutputStream(out1)); - DataInputStream in1 = new DataInputStream( - new ByteArrayInputStream(out1.toByteArray())); - { - KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1)); - FlinkKeyGroupStateInternals state1 = - new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - state1.restoreKeyGroupState(1, in1, classLoader); - BagState valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3")); - } - - // 2. scale down - { - KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - FlinkKeyGroupStateInternals newAllState = new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - in0.reset(); - in1.reset(); - newAllState.restoreKeyGroupState(0, in0, classLoader); - newAllState.restoreKeyGroupState(1, in1, classLoader); - BagState valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); + /** + * A specific test of FlinkKeyGroupStateInternalsTest. 
+ */ + @RunWith(JUnit4.class) + public static class OtherTests { + + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateTag> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + + @Test + public void testKeyGroupAndCheckpoint() throws Exception { + // assign to keyGroup 0 + ByteBuffer key0 = ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111")); + // assign to keyGroup 1 + ByteBuffer key1 = ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222")); + FlinkKeyGroupStateInternals allState; + { + KeyedStateBackend keyedStateBackend = + getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + allState = new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + BagState valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR); + keyedStateBackend.setCurrentKey(key0); + valueForNamespace0.add("0"); + valueForNamespace1.add("2"); + keyedStateBackend.setCurrentKey(key1); + valueForNamespace0.add("1"); + valueForNamespace1.add("3"); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); + } + + ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader(); + + // 1. scale up + ByteArrayOutputStream out0 = new ByteArrayOutputStream(); + allState.snapshotKeyGroupState(0, new DataOutputStream(out0)); + DataInputStream in0 = new DataInputStream( + new ByteArrayInputStream(out0.toByteArray())); + { + KeyedStateBackend keyedStateBackend = + getKeyedStateBackend(2, new KeyGroupRange(0, 0)); + FlinkKeyGroupStateInternals state0 = + new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + state0.restoreKeyGroupState(0, in0, classLoader); + BagState valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2")); + } + + ByteArrayOutputStream out1 = new ByteArrayOutputStream(); + allState.snapshotKeyGroupState(1, new DataOutputStream(out1)); + DataInputStream in1 = new DataInputStream( + new ByteArrayInputStream(out1.toByteArray())); + { + KeyedStateBackend keyedStateBackend = + getKeyedStateBackend(2, new KeyGroupRange(1, 1)); + FlinkKeyGroupStateInternals state1 = + new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + state1.restoreKeyGroupState(1, in1, classLoader); + BagState valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3")); + } + + // 2. 
scale down + { + KeyedStateBackend keyedStateBackend = + getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + FlinkKeyGroupStateInternals newAllState = new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + in0.reset(); + in1.reset(); + newAllState.restoreKeyGroupState(0, in0, classLoader); + newAllState.restoreKeyGroupState(1, in1, classLoader); + BagState valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); + } } } - private static class TestKeyContext implements KeyContext { - - private Object key; - - @Override - public void setCurrentKey(Object key) { - this.key = key; - } - - @Override - public Object getCurrentKey() { - return key; + private static KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups, + KeyGroupRange keyGroupRange) { + MemoryStateBackend backend = new MemoryStateBackend(); + try { + AbstractKeyedStateBackend keyedStateBackend = backend.createKeyedStateBackend( + new DummyEnvironment("test", 1, 0), + new JobID(), + "test_op", + new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()), + numberOfKeyGroups, + keyGroupRange, + new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); + keyedStateBackend.setCurrentKey(ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1"))); + return keyedStateBackend; + } catch (Exception e) { + throw new RuntimeException(e); } } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java index 17cd3f5d2c845..667b5ba39544c 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java @@ -17,85 +17,115 @@ */ package org.apache.beam.runners.flink.streaming; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertThat; - -import org.apache.beam.runners.core.StateNamespace; -import org.apache.beam.runners.core.StateNamespaceForTest; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.ReadableState; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.hamcrest.Matchers; -import org.junit.Before; -import org.junit.Test; +import org.junit.Ignore; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for - * {@code InMemoryStateInternals}. + * {@code StateInternalsTest}. + * + *

    Just test testBag and testBagIsEmpty. */ @RunWith(JUnit4.class) -public class FlinkSplitStateInternalsTest { - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - - private static final StateTag> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - - FlinkSplitStateInternals underTest; +public class FlinkSplitStateInternalsTest extends StateInternalsTest { - @Before - public void initStateInternals() { + @Override + protected StateInternals createStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), ""); - underTest = new FlinkSplitStateInternals<>(operatorStateBackend); - + return new FlinkSplitStateInternals<>(operatorStateBackend); } catch (Exception e) { throw new RuntimeException(e); } } - @Test - public void testBag() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testMergeBagIntoSource() {} - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + @Override + @Ignore + public void testMergeBagIntoNewNamespace() {} - assertThat(value.read(), Matchers.emptyIterable()); - value.add("hello"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + @Override + @Ignore + public void testValue() {} - value.add("world"); - assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + @Override + @Ignore + public void testSet() {} - value.clear(); - assertThat(value.read(), Matchers.emptyIterable()); - assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); + @Override + @Ignore + public void testSetIsEmpty() {} - } + @Override + @Ignore + public void testMergeSetIntoSource() {} - @Test - public void testBagIsEmpty() throws Exception { - BagState value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + @Override + @Ignore + public void testMergeSetIntoNewNamespace() {} - assertThat(value.isEmpty().read(), Matchers.is(true)); - ReadableState readFuture = value.isEmpty(); - value.add("hello"); - assertThat(readFuture.read(), Matchers.is(false)); + @Override + @Ignore + public void testMap() {} - value.clear(); - assertThat(readFuture.read(), Matchers.is(true)); - } + @Override + @Ignore + public void testCombiningValue() {} + + @Override + @Ignore + public void testCombiningIsEmpty() {} + + @Override + @Ignore + public void testMergeCombiningValueIntoSource() {} + + @Override + @Ignore + public void testMergeCombiningValueIntoNewNamespace() {} + + @Override + @Ignore + public void testWatermarkEarliestState() {} + + @Override + @Ignore + public void testWatermarkLatestState() {} + + @Override + @Ignore + public void testWatermarkEndOfWindowState() {} + + @Override + @Ignore + public void testWatermarkStateIsEmpty() {} + + @Override + @Ignore + public void testMergeEarliestWatermarkIntoSource() {} + + @Override + @Ignore + public void testMergeLatestWatermarkIntoSource() {} + + @Override + @Ignore + public void testSetReadable() {} + + @Override + @Ignore + public void testMapReadable() {} } diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml index ddb4aca73327b..d1dba323b94ff 100644 --- a/runners/spark/pom.xml +++ b/runners/spark/pom.xml @@ -321,6 +321,13 @@ test-jar test + + + org.apache.beam + 
beam-runners-core-java + test-jar + test + diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java new file mode 100644 index 0000000000000..b4597f9ef86af --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.stateful; + +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsTest; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link SparkStateInternals}. This is based on {@link StateInternalsTest}. + * Ignore set and map tests. + */ +@RunWith(JUnit4.class) +public class SparkStateInternalsTest extends StateInternalsTest { + + @Override + protected StateInternals createStateInternals() { + return SparkStateInternals.forKey("dummyKey"); + } + + @Override + @Ignore + public void testSet() {} + + @Override + @Ignore + public void testSetIsEmpty() {} + + @Override + @Ignore + public void testMergeSetIntoSource() {} + + @Override + @Ignore + public void testMergeSetIntoNewNamespace() {} + + @Override + @Ignore + public void testMap() {} + + @Override + @Ignore + public void testSetReadable() {} + + @Override + @Ignore + public void testMapReadable() {} + +} From 027d89c91f8851195dedbab2879699738376ff77 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Thu, 15 Jun 2017 13:54:47 -0700 Subject: [PATCH 045/200] Use the appropriate context in CombineTest Coder The Accumulator was improperly decoding the seed value in the outer context, as it is in the nested context. 
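
For background on the fix below: in Beam's coder model, every component of a composite
encoding except the last one must be written with the nested context (which, for
variable-length coders such as StringUtf8Coder, adds a length prefix), while only the
final component may use the outer context, which is allowed to rely on end-of-stream to
find the value boundary. The following is a minimal illustrative sketch, not part of
this patch; the class name is invented, and it uses the same deprecated context-aware
encode() overload that the test coder exercises:

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import org.apache.beam.sdk.coders.Coder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;

    public class NestedContextSketch {
      public static void main(String[] args) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // Non-terminal component: the nested context writes a length prefix,
        // so a decoder can tell where "seed" ends and the next value begins.
        StringUtf8Coder.of().encode("seed", out, Coder.Context.NESTED);
        // Terminal component: the outer context may omit the length prefix
        // and read to end-of-stream, so it is only safe for the last value.
        StringUtf8Coder.of().encode("value", out, Coder.Context.OUTER);
        System.out.println(out.size() + " bytes written");
      }
    }

Decoding must mirror this layout: the first component has to be read with
context.nested() so its stored length is consumed, which is exactly what the hunk
below changes for the seed field.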
--- .../java/org/apache/beam/sdk/transforms/CombineTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java index 6a4348de5567e..c4ba62d148fae 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java @@ -1000,7 +1000,7 @@ public void encode(Accumulator accumulator, OutputStream outStream) @Override public void encode(Accumulator accumulator, OutputStream outStream, Coder.Context context) throws CoderException, IOException { - StringUtf8Coder.of().encode(accumulator.seed, outStream, context); + StringUtf8Coder.of().encode(accumulator.seed, outStream, context.nested()); StringUtf8Coder.of().encode(accumulator.value, outStream, context); } @@ -1012,9 +1012,9 @@ public Accumulator decode(InputStream inStream) throws CoderException, IOExcepti @Override public Accumulator decode(InputStream inStream, Coder.Context context) throws CoderException, IOException { - return new Accumulator( - StringUtf8Coder.of().decode(inStream, context), - StringUtf8Coder.of().decode(inStream, context)); + String seed = StringUtf8Coder.of().decode(inStream, context.nested()); + String value = StringUtf8Coder.of().decode(inStream, context); + return new Accumulator(seed, value); } }; } From c597a020c87f272a5920bf29195d9e314af2f828 Mon Sep 17 00:00:00 2001 From: Marian Dvorsky Date: Fri, 16 Jun 2017 15:57:13 +0200 Subject: [PATCH 046/200] Fixed handling of use_public_ips. Added test. --- .../apache_beam/options/pipeline_options.py | 9 ++++++- .../dataflow/internal/apiclient_test.py | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 283b340ecfc37..8644e51b2dbe0 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -465,7 +465,14 @@ def _add_argparse_args(cls, parser): parser.add_argument( '--use_public_ips', default=None, - help='Whether to assign public IP addresses to the worker machines.') + action='store_true', + help='Whether to assign public IP addresses to the worker VMs.') + parser.add_argument( + '--no_use_public_ips', + dest='use_public_ips', + default=None, + action='store_false', + help='Whether to assign only private IP addresses to the worker VMs.') def validate(self, validator): errors = [] diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 67cf77fcb93f2..55211f7588aae 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -122,6 +122,30 @@ def test_translate_means(self): self.assertEqual( metric_update.floatingPointMean.count.lowBits, accumulator.count) + def test_default_ip_configuration(self): + pipeline_options = PipelineOptions( + ['--temp_location', 'gs://any-location/temp']) + env = apiclient.Environment([], pipeline_options, '2.0.0') + self.assertEqual(env.proto.workerPools[0].ipConfiguration, None) + + def test_public_ip_configuration(self): + pipeline_options = PipelineOptions( + ['--temp_location', 'gs://any-location/temp', + '--use_public_ips']) + env = apiclient.Environment([], 
pipeline_options, '2.0.0') + self.assertEqual( + env.proto.workerPools[0].ipConfiguration, + dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC) + + def test_private_ip_configuration(self): + pipeline_options = PipelineOptions( + ['--temp_location', 'gs://any-location/temp', + '--no_use_public_ips']) + env = apiclient.Environment([], pipeline_options, '2.0.0') + self.assertEqual( + env.proto.workerPools[0].ipConfiguration, + dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE) + if __name__ == '__main__': unittest.main() From 5aee624cbc2815efaf04c7e4854138370a45a1f6 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Thu, 15 Jun 2017 14:27:47 -0700 Subject: [PATCH 047/200] Introduce pending bundles and RootBundleProvider in DirectRunner --- .../runners/direct/bundle_factory.py | 2 +- .../apache_beam/runners/direct/executor.py | 64 +++++++++++-------- .../runners/direct/transform_evaluator.py | 39 ++++++++++- 3 files changed, 77 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index ed00b03310bbd..0182b4c9e9cad 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -108,7 +108,7 @@ def windowed_values(self): self._initial_windowed_value.windows) def __init__(self, pcollection, stacked=True): - assert isinstance(pcollection, pvalue.PCollection) + assert isinstance(pcollection, (pvalue.PBegin, pvalue.PCollection)) self._pcollection = pcollection self._elements = [] self._stacked = stacked diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 86db29159f375..eff2d3c41e66d 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -20,6 +20,7 @@ from __future__ import absolute_import import collections +import itertools import logging import Queue import sys @@ -250,12 +251,12 @@ class TransformExecutor(_ExecutorService.CallableTask): """ def __init__(self, transform_evaluator_registry, evaluation_context, - input_bundle, applied_transform, completion_callback, + input_bundle, applied_ptransform, completion_callback, transform_evaluation_state): self._transform_evaluator_registry = transform_evaluator_registry self._evaluation_context = evaluation_context self._input_bundle = input_bundle - self._applied_transform = applied_transform + self._applied_ptransform = applied_ptransform self._completion_callback = completion_callback self._transform_evaluation_state = transform_evaluation_state self._side_input_values = {} @@ -264,11 +265,11 @@ def __init__(self, transform_evaluator_registry, evaluation_context, def call(self): self._call_count += 1 - assert self._call_count <= (1 + len(self._applied_transform.side_inputs)) - metrics_container = MetricsContainer(self._applied_transform.full_label) + assert self._call_count <= (1 + len(self._applied_ptransform.side_inputs)) + metrics_container = MetricsContainer(self._applied_ptransform.full_label) scoped_metrics_container = ScopedMetricsContainer(metrics_container) - for side_input in self._applied_transform.side_inputs: + for side_input in self._applied_ptransform.side_inputs: if side_input not in self._side_input_values: has_result, value = ( self._evaluation_context.get_value_or_schedule_after_output( @@ -280,11 +281,11 @@ def call(self): self._side_input_values[side_input] = value 
side_input_values = [self._side_input_values[side_input] - for side_input in self._applied_transform.side_inputs] + for side_input in self._applied_ptransform.side_inputs] try: - evaluator = self._transform_evaluator_registry.for_application( - self._applied_transform, self._input_bundle, + evaluator = self._transform_evaluator_registry.get_evaluator( + self._applied_ptransform, self._input_bundle, side_input_values, scoped_metrics_container) if self._input_bundle: @@ -298,13 +299,13 @@ def call(self): if self._evaluation_context.has_cache: for uncommitted_bundle in result.uncommitted_output_bundles: self._evaluation_context.append_to_cache( - self._applied_transform, uncommitted_bundle.tag, + self._applied_ptransform, uncommitted_bundle.tag, uncommitted_bundle.get_elements_iterable()) undeclared_tag_values = result.undeclared_tag_values if undeclared_tag_values: for tag, value in undeclared_tag_values.iteritems(): self._evaluation_context.append_to_cache( - self._applied_transform, tag, value) + self._applied_ptransform, tag, value) self._completion_callback.handle_result(self._input_bundle, result) return result @@ -353,6 +354,15 @@ def __init__(self, value_to_consumers, transform_evaluator_registry, def start(self, roots): self.root_nodes = frozenset(roots) + self.all_nodes = frozenset( + itertools.chain( + roots, + *itertools.chain(self.value_to_consumers.values()))) + self.node_to_pending_bundles = {} + for root_node in self.root_nodes: + provider = (self.transform_evaluator_registry + .get_root_bundle_provider(root_node)) + self.node_to_pending_bundles[root_node] = provider.get_root_bundles() self.executor_service.submit( _ExecutorServiceParallelExecutor._MonitorTask(self)) @@ -372,22 +382,22 @@ def schedule_consumers(self, committed_bundle): self.schedule_consumption(applied_ptransform, committed_bundle, self.default_completion_callback) - def schedule_consumption(self, consumer_applied_transform, committed_bundle, + def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, on_complete): """Schedules evaluation of the given bundle with the transform.""" - assert all([consumer_applied_transform, on_complete]) - assert committed_bundle or consumer_applied_transform in self.root_nodes - if (committed_bundle - and self.transform_evaluator_registry.should_execute_serially( - consumer_applied_transform)): + assert consumer_applied_ptransform + assert committed_bundle + assert on_complete + if self.transform_evaluator_registry.should_execute_serially( + consumer_applied_ptransform): transform_executor_service = self.transform_executor_services.serial( - consumer_applied_transform) + consumer_applied_ptransform) else: transform_executor_service = self.transform_executor_services.parallel() transform_executor = TransformExecutor( self.transform_evaluator_registry, self.evaluation_context, - committed_bundle, consumer_applied_transform, on_complete, + committed_bundle, consumer_applied_ptransform, on_complete, transform_executor_service) transform_executor_service.schedule(transform_executor) @@ -564,10 +574,14 @@ def _add_work_if_necessary(self, timers_fired): # additional work. return - # All current TransformExecutors are blocked; add more work from the - # roots. 
- for applied_transform in self._executor.root_nodes: - if not self._executor.evaluation_context.is_done(applied_transform): - self._executor.schedule_consumption( - applied_transform, None, - self._executor.default_completion_callback) + # All current TransformExecutors are blocked; add more work from any + # pending bundles. + for applied_ptransform in self._executor.all_nodes: + if not self._executor.evaluation_context.is_done(applied_ptransform): + pending_bundles = self._executor.node_to_pending_bundles.get( + applied_ptransform, []) + for bundle in pending_bundles: + self._executor.schedule_consumption( + applied_ptransform, bundle, + self._executor.default_completion_callback) + self._executor.node_to_pending_bundles[applied_ptransform] = [] diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index f5b5db5c0a773..6e73561d3fe08 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -58,8 +58,11 @@ def __init__(self, evaluation_context): core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, _NativeWrite: _NativeWriteEvaluator, } + self._root_bundle_providers = { + core.PTransform: DefaultRootBundleProvider, + } - def for_application( + def get_evaluator( self, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): """Returns a TransformEvaluator suitable for processing given inputs.""" @@ -81,6 +84,18 @@ def for_application( input_committed_bundle, side_inputs, scoped_metrics_container) + def get_root_bundle_provider(self, applied_ptransform): + provider_cls = None + for cls in applied_ptransform.transform.__class__.mro(): + provider_cls = self._root_bundle_providers.get(cls) + if provider_cls: + break + if not provider_cls: + raise NotImplementedError( + 'Root provider for [%s] not implemented in runner %s' % ( + type(applied_ptransform.transform), self)) + return provider_cls(self._evaluation_context, applied_ptransform) + def should_execute_serially(self, applied_ptransform): """Returns True if this applied_ptransform should run one bundle at a time. 
@@ -104,6 +119,27 @@ def should_execute_serially(self, applied_ptransform): (core._GroupByKeyOnly, _NativeWrite)) +class RootBundleProvider(object): + """Provides bundles for the initial execution of a root transform.""" + + def __init__(self, evaluation_context, applied_ptransform): + self._evaluation_context = evaluation_context + self._applied_ptransform = applied_ptransform + + def get_root_bundles(self): + raise NotImplementedError + + +class DefaultRootBundleProvider(RootBundleProvider): + """Provides an empty bundle by default for root transforms.""" + + def get_root_bundles(self): + input_node = pvalue.PBegin(self._applied_ptransform.transform.pipeline) + empty_bundle = ( + self._evaluation_context.create_empty_committed_bundle(input_node)) + return [empty_bundle] + + class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" @@ -180,7 +216,6 @@ class _BoundedReadEvaluator(_TransformEvaluator): def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container): - assert not input_committed_bundle assert not side_inputs self._source = applied_ptransform.transform.source self._source.pipeline_options = evaluation_context.pipeline_options From c9c1a05dc07a9a7e57fefbe6e43f723b330499d5 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Thu, 15 Jun 2017 16:36:22 -0700 Subject: [PATCH 048/200] [BEAM-1347] Break apart ProcessBundleHandler to use service locator pattern based upon URNs. This cleans up ProcessBundleHandler and allows for separate improvements of the various PTransform handler factories. --- sdks/java/harness/pom.xml | 6 + .../harness/control/ProcessBundleHandler.java | 293 +++------- .../runners/core/BeamFnDataReadRunner.java | 70 ++- .../runners/core/BeamFnDataWriteRunner.java | 67 ++- .../runners/core/BoundedSourceRunner.java | 74 ++- .../beam/runners/core/DoFnRunnerFactory.java | 182 ++++++ .../runners/core/PTransformRunnerFactory.java | 81 +++ .../control/ProcessBundleHandlerTest.java | 521 +++--------------- .../core/BeamFnDataReadRunnerTest.java | 112 +++- .../core/BeamFnDataWriteRunnerTest.java | 120 +++- .../runners/core/BoundedSourceRunnerTest.java | 124 ++++- .../runners/core/DoFnRunnerFactoryTest.java | 209 +++++++ 12 files changed, 1134 insertions(+), 725 deletions(-) create mode 100644 sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java create mode 100644 sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java create mode 100644 sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml index 61a170ae4afba..a35481d7b58d7 100644 --- a/sdks/java/harness/pom.xml +++ b/sdks/java/harness/pom.xml @@ -154,6 +154,12 @@ slf4j-api + + com.google.auto.service + auto-service + true + + org.hamcrest diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index e33277af15bcc..4c4f73d4326b0 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -18,51 +18,32 @@ package org.apache.beam.fn.harness.control; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; - 
-import com.google.common.collect.Collections2; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Lists; import com.google.common.collect.Multimap; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.InvalidProtocolBufferException; +import com.google.common.collect.Sets; import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.ServiceLoader; +import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fake.FakeStepContext; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.BeamFnDataReadRunner; -import org.apache.beam.runners.core.BeamFnDataWriteRunner; -import org.apache.beam.runners.core.BoundedSourceRunner; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners; -import org.apache.beam.runners.core.DoFnRunners.OutputManager; -import org.apache.beam.runners.core.NullSideInputReader; -import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.runners.core.PTransformRunnerFactory; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.sdk.util.common.ReflectHelpers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -75,25 +56,73 @@ * and finishing all runners in forward topological order. */ public class ProcessBundleHandler { + // TODO: What should the initial set of URNs be? private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1"; - private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1"; - private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1"; - private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; + public static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; private static final Logger LOG = LoggerFactory.getLogger(ProcessBundleHandler.class); + private static final Map REGISTERED_RUNNER_FACTORIES; + + static { + Set pipelineRunnerRegistrars = + Sets.newTreeSet(ReflectHelpers.ObjectsClassComparator.INSTANCE); + pipelineRunnerRegistrars.addAll( + Lists.newArrayList(ServiceLoader.load(Registrar.class, + ReflectHelpers.findClassLoader()))); + + // Load all registered PTransform runner factories. 
+ ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (Registrar registrar : pipelineRunnerRegistrars) { + builder.putAll(registrar.getPTransformRunnerFactories()); + } + REGISTERED_RUNNER_FACTORIES = builder.build(); + } private final PipelineOptions options; private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; + private final Map urnToPTransformRunnerFactoryMap; + private final PTransformRunnerFactory defaultPTransformRunnerFactory; + public ProcessBundleHandler( PipelineOptions options, Function fnApiRegistry, BeamFnDataClient beamFnDataClient) { + this(options, fnApiRegistry, beamFnDataClient, REGISTERED_RUNNER_FACTORIES); + } + + @VisibleForTesting + ProcessBundleHandler( + PipelineOptions options, + Function fnApiRegistry, + BeamFnDataClient beamFnDataClient, + Map urnToPTransformRunnerFactoryMap) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; + this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; + this.defaultPTransformRunnerFactory = new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) { + throw new IllegalStateException(String.format( + "No factory registered for %s, known factories %s", + pTransform.getSpec().getUrn(), + urnToPTransformRunnerFactoryMap.keySet())); + } + }; } private void createRunnerAndConsumersForPTransformRecursively( @@ -128,115 +157,19 @@ private void createRunnerAndConsumersForPTransformRecursively( } } - createRunnerForPTransform( - pTransformId, - pTransform, - processBundleInstructionId, - processBundleDescriptor.getPcollectionsMap(), - pCollectionIdsToConsumers, - addStartFunction, - addFinishFunction); - } - - protected void createRunnerForPTransform( - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Multimap>> pCollectionIdsToConsumers, - Consumer addStartFunction, - Consumer addFinishFunction) throws IOException { - - - // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder>>> - outputMapBuilder = ImmutableMap.builder(); - for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - outputMapBuilder.put( - entry.getKey(), - pCollectionIdsToConsumers.get(entry.getValue())); - } - ImmutableMap>>> outputMap = - outputMapBuilder.build(); - - - // Based upon the function spec, populate the start/finish/consumer information. 
- RunnerApi.FunctionSpec functionSpec = pTransform.getSpec(); - ThrowingConsumer> consumer; - switch (functionSpec.getUrn()) { - default: - BeamFnApi.Target target; - RunnerApi.Coder coderSpec; - throw new IllegalArgumentException( - String.format("Unknown FunctionSpec %s", functionSpec)); - - case DATA_OUTPUT_URN: - target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(pTransformId) - .setName(getOnlyElement(pTransform.getInputsMap().keySet())) - .build(); - coderSpec = (RunnerApi.Coder) fnApiRegistry.apply( - pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId()); - BeamFnDataWriteRunner remoteGrpcWriteRunner = - new BeamFnDataWriteRunner( - functionSpec, - processBundleInstructionId, - target, - coderSpec, - beamFnDataClient); - addStartFunction.accept(remoteGrpcWriteRunner::registerForOutput); - consumer = (ThrowingConsumer) - (ThrowingConsumer>) remoteGrpcWriteRunner::consume; - addFinishFunction.accept(remoteGrpcWriteRunner::close); - break; - - case DATA_INPUT_URN: - target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(pTransformId) - .setName(getOnlyElement(pTransform.getOutputsMap().keySet())) - .build(); - coderSpec = (RunnerApi.Coder) fnApiRegistry.apply( - pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId()); - BeamFnDataReadRunner remoteGrpcReadRunner = - new BeamFnDataReadRunner( - functionSpec, - processBundleInstructionId, - target, - coderSpec, - beamFnDataClient, - (Map) outputMap); - addStartFunction.accept(remoteGrpcReadRunner::registerInputLocation); - consumer = null; - addFinishFunction.accept(remoteGrpcReadRunner::blockTillReadFinishes); - break; - - case JAVA_DO_FN_URN: - DoFnRunner doFnRunner = createDoFnRunner(functionSpec, (Map) outputMap); - addStartFunction.accept(doFnRunner::startBundle); - consumer = (ThrowingConsumer) - (ThrowingConsumer>) doFnRunner::processElement; - addFinishFunction.accept(doFnRunner::finishBundle); - break; - - case JAVA_SOURCE_URN: - @SuppressWarnings({"unchecked", "rawtypes"}) - BoundedSourceRunner, Object> sourceRunner = - createBoundedSourceRunner(functionSpec, (Map) outputMap); - // TODO: Remove and replace with source being sent across gRPC port - addStartFunction.accept(sourceRunner::start); - consumer = (ThrowingConsumer) - (ThrowingConsumer>>) - sourceRunner::runReadLoop; - break; - } - - // If we created a consumer, add it to the map containing PCollection ids to consumers - if (consumer != null) { - for (String inputPCollectionId : - pTransform.getInputsMap().values()) { - pCollectionIdsToConsumers.put(inputPCollectionId, consumer); - } - } + urnToPTransformRunnerFactoryMap.getOrDefault( + pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory) + .createRunnerForPTransform( + options, + beamFnDataClient, + pTransformId, + pTransform, + processBundleInstructionId, + processBundleDescriptor.getPcollectionsMap(), + processBundleDescriptor.getCodersyyyMap(), + pCollectionIdsToConsumers, + addStartFunction, + addFinishFunction); } public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) @@ -299,88 +232,4 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction return response; } - - /** - * Converts a {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec} into a {@link DoFnRunner}. 
- */ - private DoFnRunner createDoFnRunner( - RunnerApi.FunctionSpec functionSpec, - Map>>> outputMap) { - ByteString serializedFn; - try { - serializedFn = functionSpec.getParameter().unpack(BytesValue.class).getValue(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException( - String.format("Unable to unwrap DoFn %s", functionSpec), e); - } - DoFnInfo doFnInfo = - (DoFnInfo) - SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); - - checkArgument( - Objects.equals( - new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), - doFnInfo.getOutputMap().keySet()), - "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", - outputMap.keySet(), - doFnInfo.getOutputMap()); - - ImmutableMultimap.Builder, - ThrowingConsumer>> tagToOutput = - ImmutableMultimap.builder(); - for (Map.Entry> entry : doFnInfo.getOutputMap().entrySet()) { - tagToOutput.putAll(entry.getValue(), outputMap.get(Long.toString(entry.getKey()))); - } - @SuppressWarnings({"unchecked", "rawtypes"}) - final Map, Collection>>> tagBasedOutputMap = - (Map) tagToOutput.build().asMap(); - - OutputManager outputManager = - new OutputManager() { - Map, Collection>>> tupleTagToOutput = - tagBasedOutputMap; - - @Override - public void output(TupleTag tag, WindowedValue output) { - try { - Collection>> consumers = - tupleTagToOutput.get(tag); - if (consumers == null) { - /* This is a normal case, e.g., if a DoFn has output but that output is not - * consumed. Drop the output. */ - return; - } - for (ThrowingConsumer> consumer : consumers) { - consumer.accept(output); - } - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }; - - @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) - DoFnRunner runner = - DoFnRunners.simpleRunner( - PipelineOptionsFactory.create(), /* TODO */ - (DoFn) doFnInfo.getDoFn(), - NullSideInputReader.empty(), /* TODO */ - outputManager, - (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), - new ArrayList<>(doFnInfo.getOutputMap().values()), - new FakeStepContext(), - (WindowingStrategy) doFnInfo.getWindowingStrategy()); - return runner; - } - - private , OutputT> - BoundedSourceRunner createBoundedSourceRunner( - RunnerApi.FunctionSpec functionSpec, - Map>>> outputMap) { - - @SuppressWarnings({"rawtypes", "unchecked"}) - BoundedSourceRunner runner = - new BoundedSourceRunner(options, functionSpec, outputMap); - return runner; - } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java index f0fe2748d51e9..9339347d4f61b 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java @@ -18,22 +18,28 @@ package org.apache.beam.runners.core; +import static com.google.common.collect.Iterables.getOnlyElement; + import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.FluentIterable; -import com.google.common.collect.ImmutableList; +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.Collection; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import 
java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.slf4j.Logger; @@ -48,9 +54,61 @@ * {@link #blockTillReadFinishes()} to finish. */ public class BeamFnDataReadRunner { - private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataReadRunner.class); + private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataReadRunner.class); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final String URN = "urn:org.apache.beam:source:runner:0.1"; + + /** A registrar which provides a factory to handle reading from the Fn Api Data Plane. */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements + PTransformRunnerFactory.Registrar { + + @Override + public Map getPTransformRunnerFactories() { + return ImmutableMap.of(URN, new Factory()); + } + } + + /** A factory for {@link BeamFnDataReadRunner}s. */ + static class Factory + implements PTransformRunnerFactory> { + + @Override + public BeamFnDataReadRunner createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + + BeamFnApi.Target target = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(pTransformId) + .setName(getOnlyElement(pTransform.getOutputsMap().keySet())) + .build(); + RunnerApi.Coder coderSpec = coders.get(pCollections.get( + getOnlyElement(pTransform.getOutputsMap().values())).getCoderId()); + Collection>> consumers = + (Collection) pCollectionIdsToConsumers.get( + getOnlyElement(pTransform.getOutputsMap().values())); + + BeamFnDataReadRunner runner = new BeamFnDataReadRunner<>( + pTransform.getSpec(), + processBundleInstructionId, + target, + coderSpec, + beamFnDataClient, + consumers); + addStartFunction.accept(runner::registerInputLocation); + addFinishFunction.accept(runner::blockTillReadFinishes); + return runner; + } + } private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor; private final Collection>> consumers; @@ -61,20 +119,20 @@ public class BeamFnDataReadRunner { private CompletableFuture readFuture; - public BeamFnDataReadRunner( + BeamFnDataReadRunner( RunnerApi.FunctionSpec functionSpec, Supplier processBundleInstructionIdSupplier, BeamFnApi.Target inputTarget, RunnerApi.Coder coderSpec, BeamFnDataClient beamFnDataClientFactory, - Map>>> outputMap) + Collection>> consumers) throws IOException { this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class) .getApiServiceDescriptor(); this.inputTarget = inputTarget; this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier; this.beamFnDataClientFactory = beamFnDataClientFactory; - this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values())); + this.consumers = 
consumers; @SuppressWarnings("unchecked") Coder> coder = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java index a48df1210a478..c2a996b234571 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java @@ -18,30 +18,91 @@ package org.apache.beam.runners.core; +import static com.google.common.collect.Iterables.getOnlyElement; + import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; /** - * Registers as a consumer with the Beam Fn Data API. Propagates and elements consumed to - * the the registered consumer. + * Registers as a consumer with the Beam Fn Data Api. Consumes elements and encodes them for + * transmission. * *
    Can be re-used serially across {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s. * For each request, call {@link #registerForOutput()} to start and call {@link #close()} to finish. */ public class BeamFnDataWriteRunner { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; + + /** A registrar which provides a factory to handle writing to the Fn Api Data Plane. */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements + PTransformRunnerFactory.Registrar { + + @Override + public Map getPTransformRunnerFactories() { + return ImmutableMap.of(URN, new Factory()); + } + } + + /** A factory for {@link BeamFnDataWriteRunner}s. */ + static class Factory + implements PTransformRunnerFactory> { + + @Override + public BeamFnDataWriteRunner createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + BeamFnApi.Target target = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(pTransformId) + .setName(getOnlyElement(pTransform.getInputsMap().keySet())) + .build(); + RunnerApi.Coder coderSpec = coders.get( + pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId()); + BeamFnDataWriteRunner runner = + new BeamFnDataWriteRunner<>( + pTransform.getSpec(), + processBundleInstructionId, + target, + coderSpec, + beamFnDataClient); + addStartFunction.accept(runner::registerForOutput); + pCollectionIdsToConsumers.put( + getOnlyElement(pTransform.getInputsMap().values()), + (ThrowingConsumer) + (ThrowingConsumer>) runner::consume); + addFinishFunction.accept(runner::close); + return runner; + } + } private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor; private final BeamFnApi.Target outputTarget; @@ -51,7 +112,7 @@ public class BeamFnDataWriteRunner { private CloseableThrowingConsumer> consumer; - public BeamFnDataWriteRunner( + BeamFnDataWriteRunner( RunnerApi.FunctionSpec functionSpec, Supplier processBundleInstructionIdSupplier, BeamFnApi.Target outputTarget, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java index 4d530b8f79ff6..3338c3a8918b5 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java @@ -18,14 +18,20 @@ package org.apache.beam.runners.core; -import com.google.common.collect.FluentIterable; +import com.google.auto.service.AutoService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Multimap; import com.google.protobuf.BytesValue; import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.util.Collection; import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import 
org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; @@ -34,21 +40,77 @@ import org.apache.beam.sdk.util.WindowedValue; /** - * A runner which creates {@link Reader}s for each {@link BoundedSource} and executes - * the {@link Reader}s read loop. + * A runner which creates {@link Reader}s for each {@link BoundedSource} sent as an input and + * executes the {@link Reader}s read loop. */ public class BoundedSourceRunner, OutputT> { + + private static final String URN = "urn:org.apache.beam:source:java:0.1"; + + /** A registrar which provides a factory to handle Java {@link BoundedSource}s. */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements + PTransformRunnerFactory.Registrar { + + @Override + public Map getPTransformRunnerFactories() { + return ImmutableMap.of(URN, new Factory()); + } + } + + /** A factory for {@link BoundedSourceRunner}. */ + static class Factory, OutputT> + implements PTransformRunnerFactory> { + @Override + public BoundedSourceRunner createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) { + + ImmutableList.Builder>> consumers = ImmutableList.builder(); + for (String pCollectionId : pTransform.getOutputsMap().values()) { + consumers.addAll(pCollectionIdsToConsumers.get(pCollectionId)); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + BoundedSourceRunner runner = new BoundedSourceRunner( + pipelineOptions, + pTransform.getSpec(), + consumers.build()); + + // TODO: Remove and replace with source being sent across gRPC port + addStartFunction.accept(runner::start); + + ThrowingConsumer runReadLoop = + (ThrowingConsumer>) runner::runReadLoop; + for (String pCollectionId : pTransform.getInputsMap().values()) { + pCollectionIdsToConsumers.put( + pCollectionId, + runReadLoop); + } + + return runner; + } + } + private final PipelineOptions pipelineOptions; private final RunnerApi.FunctionSpec definition; private final Collection>> consumers; - public BoundedSourceRunner( + BoundedSourceRunner( PipelineOptions pipelineOptions, RunnerApi.FunctionSpec definition, - Map>>> outputMap) { + Collection>> consumers) { this.pipelineOptions = pipelineOptions; this.definition = definition; - this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values())); + this.consumers = consumers; } /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java new file mode 100644 index 0000000000000..3c0b6ebcb408e --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.service.AutoService; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fake.FakeStepContext; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; + +/** + * Classes associated with converting {@link RunnerApi.PTransform}s to {@link DoFnRunner}s. + * + *
    TODO: Move DoFnRunners into SDK harness and merge the methods below into it removing this + * class. + */ +public class DoFnRunnerFactory { + + private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; + + /** A registrar which provides a factory to handle Java {@link DoFn}s. */ + @AutoService(PTransformRunnerFactory.Registrar.class) + public static class Registrar implements + PTransformRunnerFactory.Registrar { + + @Override + public Map getPTransformRunnerFactories() { + return ImmutableMap.of(URN, new Factory()); + } + } + + /** A factory for {@link DoFnRunner}s. */ + static class Factory + implements PTransformRunnerFactory> { + + @Override + public DoFnRunner createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) { + + // For every output PCollection, create a map from output name to Consumer + ImmutableMap.Builder>>> + outputMapBuilder = ImmutableMap.builder(); + for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { + outputMapBuilder.put( + entry.getKey(), + pCollectionIdsToConsumers.get(entry.getValue())); + } + ImmutableMap>>> outputMap = + outputMapBuilder.build(); + + // Get the DoFnInfo from the serialized blob. + ByteString serializedFn; + try { + serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue(); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); + } + DoFnInfo doFnInfo = + (DoFnInfo) + SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); + + // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. + checkArgument( + Objects.equals( + new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), + doFnInfo.getOutputMap().keySet()), + "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", + outputMap.keySet(), + doFnInfo.getOutputMap()); + + ImmutableMultimap.Builder, + ThrowingConsumer>> tagToOutput = + ImmutableMultimap.builder(); + for (Map.Entry> entry : doFnInfo.getOutputMap().entrySet()) { + @SuppressWarnings({"unchecked", "rawtypes"}) + Collection>> consumers = + (Collection) outputMap.get(Long.toString(entry.getKey())); + tagToOutput.putAll(entry.getValue(), consumers); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + Map, Collection>>> tagBasedOutputMap = + (Map) tagToOutput.build().asMap(); + + OutputManager outputManager = + new OutputManager() { + Map, Collection>>> tupleTagToOutput = + tagBasedOutputMap; + + @Override + public void output(TupleTag tag, WindowedValue output) { + try { + Collection>> consumers = + tupleTagToOutput.get(tag); + if (consumers == null) { + /* This is a normal case, e.g., if a DoFn has output but that output is not + * consumed. Drop the output. 
*/ + return; + } + for (ThrowingConsumer> consumer : consumers) { + consumer.accept(output); + } + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }; + + @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) + DoFnRunner runner = + DoFnRunners.simpleRunner( + pipelineOptions, + (DoFn) doFnInfo.getDoFn(), + NullSideInputReader.empty(), /* TODO */ + outputManager, + (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), + new ArrayList<>(doFnInfo.getOutputMap().values()), + new FakeStepContext(), + (WindowingStrategy) doFnInfo.getWindowingStrategy()); + + // Register the appropriate handlers. + addStartFunction.accept(runner::startBundle); + for (String pcollectionId : pTransform.getInputsMap().values()) { + pCollectionIdsToConsumers.put( + pcollectionId, + (ThrowingConsumer) (ThrowingConsumer>) runner::processElement); + } + addFinishFunction.accept(runner::finishBundle); + return runner; + } + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java new file mode 100644 index 0000000000000..b325db4545cf5 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core; + +import com.google.common.collect.Multimap; +import java.io.IOException; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; + +/** + * A factory able to instantiate an appropriate handler for a given PTransform. + */ +public interface PTransformRunnerFactory { + + /** + * Creates and returns a handler for a given PTransform. Note that the handler must support + * processing multiple bundles. The handler will be discarded if an error is thrown during + * element processing, or during execution of start/finish. + * + * @param pipelineOptions Pipeline options + * @param beamFnDataClient + * @param pTransformId The id of the PTransform. + * @param pTransform The PTransform definition. + * @param processBundleInstructionId A supplier containing the active process bundle instruction + * id. + * @param pCollections A mapping from PCollection id to PCollection definition. + * @param coders A mapping from coder id to coder definition. 
+ * @param pCollectionIdsToConsumers A mapping from PCollection id to a collection of consumers. + * Note that if this handler is a consumer, it should register itself within this multimap under + * the appropriate PCollection ids. Also note that all output consumers needed by this PTransform + * (based on the values of the {@link RunnerApi.PTransform#getOutputsMap()} will have already + * registered within this multimap. + * @param addStartFunction A consumer to register a start bundle handler with. + * @param addFinishFunction A consumer to register a finish bundle handler with. + */ + T createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException; + + /** + * A registrar which can return a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to + * a factory capable of instantiating an appropriate handler. + */ + interface Registrar { + /** + * Returns a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to a factory capable of + * instantiating an appropriate handler. + */ + Map getPTransformRunnerFactories(); + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 562f91fdd210e..a616b2c34a813 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -18,62 +18,28 @@ package org.apache.beam.fn.harness.control; -import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; -import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Suppliers; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; -import com.google.protobuf.Any; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import 
org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.dataflow.util.CloudObjects; -import org.apache.beam.runners.dataflow.util.DoFnInfo; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.runners.core.PTransformRunnerFactory; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -82,55 +48,14 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; -import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; /** Tests for {@link ProcessBundleHandler}. */ @RunWith(JUnit4.class) public class ProcessBundleHandlerTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - private static final Coder> STRING_CODER = - WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final String LONG_CODER_SPEC_ID = "998L"; - private static final String STRING_CODER_SPEC_ID = "999L"; - private static final BeamFnApi.RemoteGrpcPort REMOTE_PORT = BeamFnApi.RemoteGrpcPort.newBuilder() - .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.newBuilder() - .setId("58L") - .setUrl("TestUrl")) - .build(); - private static final RunnerApi.Coder LONG_CODER_SPEC; - private static final RunnerApi.Coder STRING_CODER_SPEC; - static { - try { - STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() - .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() - .setSpec(RunnerApi.FunctionSpec.newBuilder() - .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) - .build()))) - .build()) - .build(); - LONG_CODER_SPEC = RunnerApi.Coder.newBuilder() - .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() - .setSpec(RunnerApi.FunctionSpec.newBuilder() - .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes( - CloudObjects.asCloudObject(WindowedValue.getFullCoder(VarLongCoder.of(), - GlobalWindow.Coder.INSTANCE))))) - .build()))) - .build()) - .build(); - } catch (IOException e) { - throw new ExceptionInInitializerError(e); - } - } - private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1"; private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1"; - private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1"; - private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; @Rule public ExpectedException thrown = ExpectedException.none(); @@ -161,16 +86,16 @@ public void testOrderOfStartAndFinishCalls() throws Exception { List transformsProcessed = new ArrayList<>(); List orderOfOperations = new ArrayList<>(); - ProcessBundleHandler handler = new 
ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient) { + PTransformRunnerFactory startFinishRecorder = new PTransformRunnerFactory() { @Override - protected void createRunnerForPTransform( + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, String pTransformId, RunnerApi.PTransform pTransform, Supplier processBundleInstructionId, Map pCollections, + Map coders, Multimap>> pCollectionIdsToConsumers, Consumer addStartFunction, Consumer addFinishFunction) throws IOException { @@ -182,8 +107,18 @@ protected void createRunnerForPTransform( () -> orderOfOperations.add("Start" + pTransformId)); addFinishFunction.accept( () -> orderOfOperations.add("Finish" + pTransformId)); + return null; } }; + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + ImmutableMap.of( + DATA_INPUT_URN, startFinishRecorder, + DATA_OUTPUT_URN, startFinishRecorder)); + handler.processBundle(BeamFnApi.InstructionRequest.newBuilder() .setInstructionId("999L") .setProcessBundle( @@ -211,21 +146,25 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient) { - @Override - protected void createRunnerForPTransform( - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Multimap>> pCollectionIdsToConsumers, - Consumer addStartFunction, - Consumer addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - throw new IllegalStateException("TestException"); - } - }; + beamFnDataClient, + ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + throw new IllegalStateException("TestException"); + } + })); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) @@ -245,25 +184,26 @@ public void testPTransformStartExceptionsArePropagated() throws Exception { ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient) { - @Override - protected void createRunnerForPTransform( - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Multimap>> pCollectionIdsToConsumers, - Consumer addStartFunction, - Consumer addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - addStartFunction.accept(this::throwException); - } - - private void throwException() { - throw new IllegalStateException("TestException"); - } - }; + beamFnDataClient, + ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + 
BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addStartFunction.accept(ProcessBundleHandlerTest::throwException); + return null; + } + })); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) @@ -283,338 +223,33 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient) { - @Override - protected void createRunnerForPTransform( - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier processBundleInstructionId, - Map pCollections, - Multimap>> pCollectionIdsToConsumers, - Consumer addStartFunction, - Consumer addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - addFinishFunction.accept(this::throwException); - } - - private void throwException() { - throw new IllegalStateException("TestException"); - } - }; + beamFnDataClient, + ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory() { + @Override + public Object createRunnerForPTransform( + PipelineOptions pipelineOptions, + BeamFnDataClient beamFnDataClient, + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier processBundleInstructionId, + Map pCollections, + Map coders, + Multimap>> pCollectionIdsToConsumers, + Consumer addStartFunction, + Consumer addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addFinishFunction.accept(ProcessBundleHandlerTest::throwException); + return null; + } + })); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) .build()); } - private static class TestDoFn extends DoFn { - private static final TupleTag mainOutput = new TupleTag<>("mainOutput"); - private static final TupleTag additionalOutput = new TupleTag<>("output"); - - private BoundedWindow window; - - @ProcessElement - public void processElement(ProcessContext context, BoundedWindow window) { - context.output("MainOutput" + context.element()); - context.output(additionalOutput, "AdditionalOutput" + context.element()); - this.window = window; - } - - @FinishBundle - public void finishBundle(FinishBundleContext context) { - if (window != null) { - context.output("FinishBundle", window.maxTimestamp(), window); - window = null; - } - } - } - - /** - * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs - * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs - * are directed to the correct consumers. 
- */ - @Test - public void testCreatingAndProcessingDoFn() throws Exception { - Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String pTransformId = "100L"; - String mainOutputId = "101"; - String additionalOutputId = "102"; - - DoFnInfo doFnInfo = DoFnInfo.forFn( - new TestDoFn(), - WindowingStrategy.globalDefault(), - ImmutableList.of(), - StringUtf8Coder.of(), - Long.parseLong(mainOutputId), - ImmutableMap.of( - Long.parseLong(mainOutputId), TestDoFn.mainOutput, - Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn(JAVA_DO_FN_URN) - .setParameter(Any.pack(BytesValue.newBuilder() - .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) - .build())) - .build(); - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putInputs("inputA", "inputATarget") - .putInputs("inputB", "inputBTarget") - .putOutputs(mainOutputId, "mainOutputTarget") - .putOutputs(additionalOutputId, "additionalOutputTarget") - .build(); - - List> mainOutputValues = new ArrayList<>(); - List> additionalOutputValues = new ArrayList<>(); - Multimap>> consumers = HashMultimap.create(); - consumers.put("mainOutputTarget", - (ThrowingConsumer) (ThrowingConsumer>) mainOutputValues::add); - consumers.put("additionalOutputTarget", - (ThrowingConsumer) (ThrowingConsumer>) additionalOutputValues::add); - List startFunctions = new ArrayList<>(); - List finishFunctions = new ArrayList<>(); - - ProcessBundleHandler handler = new ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient); - handler.createRunnerForPTransform( - pTransformId, - pTransform, - Suppliers.ofInstance("57L")::get, - ImmutableMap.of(), - consumers, - startFunctions::add, - finishFunctions::add); - - Iterables.getOnlyElement(startFunctions).run(); - mainOutputValues.clear(); - - assertThat(consumers.keySet(), containsInAnyOrder( - "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); - - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); - Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); - assertThat(mainOutputValues, contains( - valueInGlobalWindow("MainOutputA1"), - valueInGlobalWindow("MainOutputA2"), - valueInGlobalWindow("MainOutputB"))); - assertThat(additionalOutputValues, contains( - valueInGlobalWindow("AdditionalOutputA1"), - valueInGlobalWindow("AdditionalOutputA2"), - valueInGlobalWindow("AdditionalOutputB"))); - mainOutputValues.clear(); - additionalOutputValues.clear(); - - Iterables.getOnlyElement(finishFunctions).run(); - assertThat( - mainOutputValues, - contains( - timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); - mainOutputValues.clear(); - } - - @Test - public void testCreatingAndProcessingSource() throws Exception { - Map fnApiRegistry = ImmutableMap.of(LONG_CODER_SPEC_ID, LONG_CODER_SPEC); - List> outputValues = new ArrayList<>(); - - Multimap>> consumers = HashMultimap.create(); - consumers.put("outputPC", - (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); - List startFunctions = new ArrayList<>(); - List finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn(JAVA_SOURCE_URN) - 
.setParameter(Any.pack(BytesValue.newBuilder() - .setValue(ByteString.copyFrom( - SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) - .build())) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putInputs("input", "inputPC") - .putOutputs("output", "outputPC") - .build(); - - ProcessBundleHandler handler = new ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient); - - handler.createRunnerForPTransform( - "pTransformId", - pTransform, - Suppliers.ofInstance("57L")::get, - ImmutableMap.of(), - consumers, - startFunctions::add, - finishFunctions::add); - - // This is testing a deprecated way of running sources and should be removed - // once all source definitions are instead propagated along the input edge. - Iterables.getOnlyElement(startFunctions).run(); - assertThat(outputValues, contains( - valueInGlobalWindow(0L), - valueInGlobalWindow(1L), - valueInGlobalWindow(2L))); - outputValues.clear(); - - // Check that when passing a source along as an input, the source is processed. - assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); - Iterables.getOnlyElement(consumers.get("inputPC")).accept( - valueInGlobalWindow(CountingSource.upTo(2))); - assertThat(outputValues, contains( - valueInGlobalWindow(0L), - valueInGlobalWindow(1L))); - - assertThat(finishFunctions, empty()); - } - - @Test - public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { - Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String bundleId = "57"; - String outputId = "101"; - - List> outputValues = new ArrayList<>(); - - Multimap>> consumers = HashMultimap.create(); - consumers.put("outputPC", - (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); - List startFunctions = new ArrayList<>(); - List finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn(DATA_INPUT_URN) - .setParameter(Any.pack(REMOTE_PORT)) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putOutputs(outputId, "outputPC") - .build(); - - ProcessBundleHandler handler = new ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient); - - handler.createRunnerForPTransform( - "pTransformId", - pTransform, - Suppliers.ofInstance(bundleId)::get, - ImmutableMap.of("outputPC", - RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), - consumers, - startFunctions::add, - finishFunctions::add); - - verifyZeroInteractions(beamFnDataClient); - - CompletableFuture completionFuture = new CompletableFuture<>(); - when(beamFnDataClient.forInboundConsumer(any(), any(), any(), any())) - .thenReturn(completionFuture); - Iterables.getOnlyElement(startFunctions).run(); - verify(beamFnDataClient).forInboundConsumer( - eq(REMOTE_PORT.getApiServiceDescriptor()), - eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("pTransformId") - .setName(outputId) - .build())), - eq(STRING_CODER), - consumerCaptor.capture()); - - consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); - assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); - outputValues.clear(); - - assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); - - completionFuture.complete(null); - Iterables.getOnlyElement(finishFunctions).run(); - - 
verifyNoMoreInteractions(beamFnDataClient); - } - - @Test - public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { - Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); - String bundleId = "57L"; - String inputId = "100L"; - - Multimap>> consumers = HashMultimap.create(); - List startFunctions = new ArrayList<>(); - List finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn(DATA_OUTPUT_URN) - .setParameter(Any.pack(REMOTE_PORT)) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putInputs(inputId, "inputPC") - .build(); - - ProcessBundleHandler handler = new ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient); - - handler.createRunnerForPTransform( - "ptransformId", - pTransform, - Suppliers.ofInstance(bundleId)::get, - ImmutableMap.of("inputPC", - RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), - consumers, - startFunctions::add, - finishFunctions::add); - - verifyZeroInteractions(beamFnDataClient); - - List> outputValues = new ArrayList<>(); - AtomicBoolean wasCloseCalled = new AtomicBoolean(); - CloseableThrowingConsumer> outputConsumer = - new CloseableThrowingConsumer>(){ - @Override - public void close() throws Exception { - wasCloseCalled.set(true); - } - - @Override - public void accept(WindowedValue t) throws Exception { - outputValues.add(t); - } - }; - - when(beamFnDataClient.forOutboundConsumer( - any(), - any(), - Matchers.>>any())).thenReturn(outputConsumer); - Iterables.getOnlyElement(startFunctions).run(); - verify(beamFnDataClient).forOutboundConsumer( - eq(REMOTE_PORT.getApiServiceDescriptor()), - eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("ptransformId") - .setName(inputId) - .build())), - eq(STRING_CODER)); - - assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); - Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); - assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); - outputValues.clear(); - - assertFalse(wasCloseCalled.get()); - Iterables.getOnlyElement(finishFunctions).run(); - assertTrue(wasCloseCalled.get()); - - verifyNoMoreInteractions(beamFnDataClient); + private static void throwException() { + throw new IllegalStateException("TestException"); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java index 7e8ab1a2216d0..d6a476ef0803d 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java @@ -20,41 +20,51 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; +import 
com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.List; -import java.util.Map; +import java.util.ServiceLoader; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.harness.test.TestExecutors; import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.hamcrest.collection.IsMapContaining; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -68,15 +78,18 @@ /** Tests for {@link BeamFnDataReadRunner}. 
*/ @RunWith(JUnit4.class) public class BeamFnDataReadRunnerTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() .setParameter(Any.pack(PORT_SPEC)).build(); private static final Coder> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String CODER_SPEC_ID = "string-coder-id"; private static final RunnerApi.Coder CODER_SPEC; + private static final String URN = "urn:org.apache.beam:source:runner:0.1"; + static { try { CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( @@ -98,7 +111,7 @@ public class BeamFnDataReadRunnerTest { .build(); @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); - @Mock private BeamFnDataClient mockBeamFnDataClientFactory; + @Mock private BeamFnDataClient mockBeamFnDataClient; @Captor private ArgumentCaptor>> consumerCaptor; @Before @@ -106,33 +119,94 @@ public void setUp() { MockitoAnnotations.initMocks(this); } + @Test + public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { + String bundleId = "57"; + String outputId = "101"; + + List> outputValues = new ArrayList<>(); + + Multimap>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:source:runner:0.1") + .setParameter(Any.pack(PORT_SPEC)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putOutputs(outputId, "outputPC") + .build(); + + new BeamFnDataReadRunner.Factory().createRunnerForPTransform( + PipelineOptionsFactory.create(), + mockBeamFnDataClient, + "pTransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("outputPC", + RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()), + ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(mockBeamFnDataClient); + + CompletableFuture completionFuture = new CompletableFuture<>(); + when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any())) + .thenReturn(completionFuture); + Iterables.getOnlyElement(startFunctions).run(); + verify(mockBeamFnDataClient).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("pTransformId") + .setName(outputId) + .build())), + eq(CODER), + consumerCaptor.capture()); + + consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); + + completionFuture.complete(null); + Iterables.getOnlyElement(finishFunctions).run(); + + verifyNoMoreInteractions(mockBeamFnDataClient); + } + @Test public void testReuseForMultipleBundles() throws Exception { CompletableFuture bundle1Future = new CompletableFuture<>(); CompletableFuture 
bundle2Future = new CompletableFuture<>(); - when(mockBeamFnDataClientFactory.forInboundConsumer( + when(mockBeamFnDataClient.forInboundConsumer( any(), any(), any(), any())).thenReturn(bundle1Future).thenReturn(bundle2Future); List> valuesA = new ArrayList<>(); List> valuesB = new ArrayList<>(); - Map>>> outputMap = ImmutableMap.of( - "outA", ImmutableList.of(valuesA::add), - "outB", ImmutableList.of(valuesB::add)); + AtomicReference bundleId = new AtomicReference<>("0"); BeamFnDataReadRunner readRunner = new BeamFnDataReadRunner<>( FUNCTION_SPEC, bundleId::get, INPUT_TARGET, CODER_SPEC, - mockBeamFnDataClientFactory, - outputMap); + mockBeamFnDataClient, + ImmutableList.of(valuesA::add, valuesB::add)); // Process for bundle id 0 readRunner.registerInputLocation(); - verify(mockBeamFnDataClientFactory).forInboundConsumer( + verify(mockBeamFnDataClient).forInboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), INPUT_TARGET)), eq(CODER), @@ -164,7 +238,7 @@ public void run() { valuesB.clear(); readRunner.registerInputLocation(); - verify(mockBeamFnDataClientFactory).forInboundConsumer( + verify(mockBeamFnDataClient).forInboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), INPUT_TARGET)), eq(CODER), @@ -190,6 +264,18 @@ public void run() { assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - verifyNoMoreInteractions(mockBeamFnDataClientFactory); + verifyNoMoreInteractions(mockBeamFnDataClient); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BeamFnDataReadRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java index a3c874e545882..64d9ea6764758 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java @@ -20,31 +20,48 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.ArrayList; +import java.util.List; +import java.util.ServiceLoader; +import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.hamcrest.collection.IsMapContaining; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -56,15 +73,18 @@ /** Tests for {@link BeamFnDataWriteRunner}. */ @RunWith(JUnit4.class) public class BeamFnDataWriteRunnerTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() .setParameter(Any.pack(PORT_SPEC)).build(); + private static final String CODER_ID = "string-coder-id"; private static final Coder> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); private static final RunnerApi.Coder CODER_SPEC; + private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; + static { try { CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( @@ -85,18 +105,93 @@ public class BeamFnDataWriteRunnerTest { .setName("out") .build(); - @Mock private BeamFnDataClient mockBeamFnDataClientFactory; + @Mock private BeamFnDataClient mockBeamFnDataClient; @Before public void setUp() { MockitoAnnotations.initMocks(this); } + + @Test + public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { + String bundleId = "57L"; + String inputId = "100L"; + + Multimap>> consumers = HashMultimap.create(); + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:sink:runner:0.1") + .setParameter(Any.pack(PORT_SPEC)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs(inputId, "inputPC") + .build(); + + new BeamFnDataWriteRunner.Factory().createRunnerForPTransform( + PipelineOptionsFactory.create(), + mockBeamFnDataClient, + "ptransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("inputPC", + RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()), + ImmutableMap.of(CODER_ID, CODER_SPEC), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(mockBeamFnDataClient); + + List> outputValues = new ArrayList<>(); + AtomicBoolean wasCloseCalled = new AtomicBoolean(); + CloseableThrowingConsumer> outputConsumer = + new CloseableThrowingConsumer>(){ + @Override + public void close() throws Exception { + wasCloseCalled.set(true); + } + + @Override + 
public void accept(WindowedValue t) throws Exception { + outputValues.add(t); + } + }; + + when(mockBeamFnDataClient.forOutboundConsumer( + any(), + any(), + Matchers.>>any())).thenReturn(outputConsumer); + Iterables.getOnlyElement(startFunctions).run(); + verify(mockBeamFnDataClient).forOutboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("ptransformId") + .setName(inputId) + .build())), + eq(CODER)); + + assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertFalse(wasCloseCalled.get()); + Iterables.getOnlyElement(finishFunctions).run(); + assertTrue(wasCloseCalled.get()); + + verifyNoMoreInteractions(mockBeamFnDataClient); + } + @Test public void testReuseForMultipleBundles() throws Exception { RecordingConsumer> valuesA = new RecordingConsumer<>(); RecordingConsumer> valuesB = new RecordingConsumer<>(); - when(mockBeamFnDataClientFactory.forOutboundConsumer( + when(mockBeamFnDataClient.forOutboundConsumer( any(), any(), Matchers.>>any())).thenReturn(valuesA).thenReturn(valuesB); @@ -106,12 +201,12 @@ public void testReuseForMultipleBundles() throws Exception { bundleId::get, OUTPUT_TARGET, CODER_SPEC, - mockBeamFnDataClientFactory); + mockBeamFnDataClient); // Process for bundle id 0 writeRunner.registerForOutput(); - verify(mockBeamFnDataClientFactory).forOutboundConsumer( + verify(mockBeamFnDataClient).forOutboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), OUTPUT_TARGET)), eq(CODER)); @@ -129,7 +224,7 @@ public void testReuseForMultipleBundles() throws Exception { valuesB.clear(); writeRunner.registerForOutput(); - verify(mockBeamFnDataClientFactory).forOutboundConsumer( + verify(mockBeamFnDataClient).forOutboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), OUTPUT_TARGET)), eq(CODER)); @@ -140,7 +235,7 @@ public void testReuseForMultipleBundles() throws Exception { assertTrue(valuesB.closed); assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - verifyNoMoreInteractions(mockBeamFnDataClientFactory); + verifyNoMoreInteractions(mockBeamFnDataClient); } private static class RecordingConsumer extends ArrayList @@ -158,6 +253,17 @@ public void accept(T t) throws Exception { } add(t); } + } + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BeamFnDataWriteRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java index d8ed121a70413..6c9a4cb7ff464 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java @@ -20,25 +20,35 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.collection.IsEmptyCollection.empty; import static 
org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map; +import java.util.ServiceLoader; import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; +import org.hamcrest.Matchers; +import org.hamcrest.collection.IsMapContaining; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,27 +56,25 @@ /** Tests for {@link BoundedSourceRunner}. */ @RunWith(JUnit4.class) public class BoundedSourceRunnerTest { + + public static final String URN = "urn:org.apache.beam:source:java:0.1"; + @Test public void testRunReadLoopWithMultipleSources() throws Exception { - List> out1ValuesA = new ArrayList<>(); - List> out1ValuesB = new ArrayList<>(); + List> out1Values = new ArrayList<>(); List> out2Values = new ArrayList<>(); - Map>>> outputMap = ImmutableMap.of( - "out1", ImmutableList.of(out1ValuesA::add, out1ValuesB::add), - "out2", ImmutableList.of(out2Values::add)); + Collection>> consumers = + ImmutableList.of(out1Values::add, out2Values::add); - BoundedSourceRunner, Long> runner = - new BoundedSourceRunner<>( + BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( PipelineOptionsFactory.create(), RunnerApi.FunctionSpec.getDefaultInstance(), - outputMap); + consumers); runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(2))); runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(1))); - assertThat(out1ValuesA, - contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L))); - assertThat(out1ValuesB, + assertThat(out1Values, contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L))); assertThat(out2Values, contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L))); @@ -74,40 +82,106 @@ public void testRunReadLoopWithMultipleSources() throws Exception { @Test public void testRunReadLoopWithEmptySource() throws Exception { - List> out1Values = new ArrayList<>(); - Map>>> outputMap = ImmutableMap.of( - "out1", ImmutableList.of(out1Values::add)); + List> outValues = new ArrayList<>(); + Collection>> consumers = + ImmutableList.of(outValues::add); - BoundedSourceRunner, Long> runner = - new BoundedSourceRunner<>( + BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( PipelineOptionsFactory.create(), RunnerApi.FunctionSpec.getDefaultInstance(), - outputMap); + consumers); runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(0))); - assertThat(out1Values, empty()); + assertThat(outValues, empty()); } @Test public void testStart() throws Exception { List> outValues = new ArrayList<>(); - Map>>> outputMap = ImmutableMap.of( - "out", 
ImmutableList.of(outValues::add)); + Collection>> consumers = + ImmutableList.of(outValues::add); ByteString encodedSource = ByteString.copyFrom(SerializableUtils.serializeToByteArray(CountingSource.upTo(3))); - BoundedSourceRunner, Long> runner = - new BoundedSourceRunner<>( + BoundedSourceRunner, Long> runner = new BoundedSourceRunner<>( PipelineOptionsFactory.create(), - RunnerApi.FunctionSpec.newBuilder().setParameter( + RunnerApi.FunctionSpec.newBuilder().setParameter( Any.pack(BytesValue.newBuilder().setValue(encodedSource).build())).build(), - outputMap); + consumers); runner.start(); assertThat(outValues, contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(2L))); } + + @Test + public void testCreatingAndProcessingSourceFromFactory() throws Exception { + List> outputValues = new ArrayList<>(); + + Multimap>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer>) outputValues::add); + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:source:java:0.1") + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom( + SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) + .build())) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputPC") + .putOutputs("output", "outputPC") + .build(); + + new BoundedSourceRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + "pTransformId", + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + // This is testing a deprecated way of running sources and should be removed + // once all source definitions are instead propagated along the input edge. + Iterables.getOnlyElement(startFunctions).run(); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L), + valueInGlobalWindow(2L))); + outputValues.clear(); + + // Check that when passing a source along as an input, the source is processed. + assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept( + valueInGlobalWindow(CountingSource.upTo(2))); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L))); + + assertThat(finishFunctions, Matchers.empty()); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BoundedSourceRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java new file mode 100644 index 0000000000000..62646ffa9710f --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core; + +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DoFnRunnerFactory}. 
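+ * Covers factory-based construction from a {@link RunnerApi.PTransform}, routing of elements to
+ * main and additional output consumers, and discovery via {@link ServiceLoader} registration.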
*/ +@RunWith(JUnit4.class) +public class DoFnRunnerFactoryTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final Coder> STRING_CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String STRING_CODER_SPEC_ID = "999L"; + private static final RunnerApi.Coder STRING_CODER_SPEC; + private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; + + static { + try { + STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) + .build()))) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + + private static class TestDoFn extends DoFn { + private static final TupleTag mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag additionalOutput = new TupleTag<>("output"); + + private BoundedWindow window; + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + context.output("MainOutput" + context.element()); + context.output(additionalOutput, "AdditionalOutput" + context.element()); + this.window = window; + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + if (window != null) { + context.output("FinishBundle", window.maxTimestamp(), window); + window = null; + } + } + } + + /** + * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs + * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs + * are directed to the correct consumers. 
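+   * Finish-bundle output is also verified: {@link TestDoFn} emits it through
+   * {@code FinishBundleContext} at the end of the global window.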
+ */ + @Test + public void testCreatingAndProcessingDoFn() throws Exception { + Map fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String pTransformId = "pTransformId"; + String mainOutputId = "101"; + String additionalOutputId = "102"; + + DoFnInfo doFnInfo = DoFnInfo.forFn( + new TestDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + StringUtf8Coder.of(), + Long.parseLong(mainOutputId), + ImmutableMap.of( + Long.parseLong(mainOutputId), TestDoFn.mainOutput, + Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:dofn:java:0.1") + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build())) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("inputA", "inputATarget") + .putInputs("inputB", "inputBTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs(additionalOutputId, "additionalOutputTarget") + .build(); + + List> mainOutputValues = new ArrayList<>(); + List> additionalOutputValues = new ArrayList<>(); + Multimap>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer>) mainOutputValues::add); + consumers.put("additionalOutputTarget", + (ThrowingConsumer) (ThrowingConsumer>) additionalOutputValues::add); + List startFunctions = new ArrayList<>(); + List finishFunctions = new ArrayList<>(); + + new DoFnRunnerFactory.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + pTransformId, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder( + "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); + + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("MainOutputA1"), + valueInGlobalWindow("MainOutputA2"), + valueInGlobalWindow("MainOutputB"))); + assertThat(additionalOutputValues, contains( + valueInGlobalWindow("AdditionalOutputA1"), + valueInGlobalWindow("AdditionalOutputA2"), + valueInGlobalWindow("AdditionalOutputB"))); + mainOutputValues.clear(); + additionalOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat( + mainOutputValues, + contains( + timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); + mainOutputValues.clear(); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof DoFnRunnerFactory.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } +} From 9a6a277cea4582f0a64eac97730cb85af5ba352b Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Thu, 8 Jun 2017 16:54:12 -0700 Subject: [PATCH 049/200] Tests for reading windowed side input from resumed 
SDF call --- .../sdk/transforms/SplittableDoFnTest.java | 145 +++++++++++++++++- 1 file changed, 140 insertions(+), 5 deletions(-) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index 02a44d2b907e1..646d8d310bf8a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -18,10 +18,13 @@ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkState; +import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Ordering; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; @@ -29,6 +32,7 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; @@ -182,6 +186,12 @@ public void testPairWithIndexWindowedTimestamped() { private static class SDFWithMultipleOutputsPerBlock extends DoFn { private static final int MAX_INDEX = 98765; + private final TupleTag numProcessCalls; + + private SDFWithMultipleOutputsPerBlock(TupleTag numProcessCalls) { + this.numProcessCalls = numProcessCalls; + } + private static int snapToNextBlock(int index, int[] blockStarts) { for (int i = 1; i < blockStarts.length; ++i) { if (index > blockStarts[i - 1] && index <= blockStarts[i]) { @@ -195,6 +205,7 @@ private static int snapToNextBlock(int index, int[] blockStarts) { public void processElement(ProcessContext c, OffsetRangeTracker tracker) { int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); + c.output(numProcessCalls, 1); for (int i = trueStart; tracker.tryClaim(blockStarts[i]); ++i) { for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { c.output(index); @@ -211,10 +222,26 @@ public OffsetRange getInitialRange(String element) { @Test @Category({ValidatesRunner.class, UsesSplittableParDo.class}) public void testOutputAfterCheckpoint() throws Exception { - PCollection outputs = p.apply(Create.of("foo")) - .apply(ParDo.of(new SDFWithMultipleOutputsPerBlock())); - PAssert.thatSingleton(outputs.apply(Count.globally())) + TupleTag main = new TupleTag<>(); + TupleTag numProcessCalls = new TupleTag<>(); + PCollectionTuple outputs = + p.apply(Create.of("foo")) + .apply( + ParDo.of(new SDFWithMultipleOutputsPerBlock(numProcessCalls)) + .withOutputTags(main, TupleTagList.of(numProcessCalls))); + PAssert.thatSingleton(outputs.get(main).apply(Count.globally())) .isEqualTo((long) SDFWithMultipleOutputsPerBlock.MAX_INDEX); + // Verify that more than 1 process() call was involved, i.e. that there was checkpointing. 
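+    // (The SDF outputs a single 1 to numProcessCalls at the top of each process() call, so a sum
+    // greater than 1 is only possible if the tracker checkpointed and the call was resumed.)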
+ PAssert.thatSingleton( + outputs.get(numProcessCalls).setCoder(VarIntCoder.of()).apply(Sum.integersGlobally())) + .satisfies( + new SerializableFunction() { + @Override + public Void apply(Integer input) { + assertThat(input, greaterThan(1)); + return null; + } + }); p.run(); } @@ -287,9 +314,117 @@ public void testWindowedSideInput() throws Exception { PAssert.that(res).containsInAnyOrder("a:0", "a:1", "a:2", "a:3", "b:4", "b:5", "b:6", "b:7"); p.run(); + } + + @BoundedPerElement + private static class SDFWithMultipleOutputsPerBlockAndSideInput + extends DoFn> { + private static final int MAX_INDEX = 98765; + private final PCollectionView sideInput; + private final TupleTag numProcessCalls; + + public SDFWithMultipleOutputsPerBlockAndSideInput( + PCollectionView sideInput, TupleTag numProcessCalls) { + this.sideInput = sideInput; + this.numProcessCalls = numProcessCalls; + } + + private static int snapToNextBlock(int index, int[] blockStarts) { + for (int i = 1; i < blockStarts.length; ++i) { + if (index > blockStarts[i - 1] && index <= blockStarts[i]) { + return i; + } + } + throw new IllegalStateException("Shouldn't get here"); + } + + @ProcessElement + public void processElement(ProcessContext c, OffsetRangeTracker tracker) { + int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; + int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); + c.output(numProcessCalls, 1); + for (int i = trueStart; tracker.tryClaim(blockStarts[i]); ++i) { + for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { + c.output(KV.of(c.sideInput(sideInput) + ":" + c.element(), index)); + } + } + } + + @GetInitialRestriction + public OffsetRange getInitialRange(Integer element) { + return new OffsetRange(0, MAX_INDEX); + } + } + + @Test + @Category({ + ValidatesRunner.class, + UsesSplittableParDo.class, + UsesSplittableParDoWithWindowedSideInputs.class + }) + public void testWindowedSideInputWithCheckpoints() throws Exception { + PCollection mainInput = + p.apply("main", + Create.timestamped( + TimestampedValue.of(0, new Instant(0)), + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(2)), + TimestampedValue.of(3, new Instant(3)))) + .apply("window 1", Window.into(FixedWindows.of(Duration.millis(1)))); + + PCollectionView sideInput = + p.apply("side", + Create.timestamped( + TimestampedValue.of("a", new Instant(0)), + TimestampedValue.of("b", new Instant(2)))) + .apply("window 2", Window.into(FixedWindows.of(Duration.millis(2)))) + .apply("singleton", View.asSingleton()); + + TupleTag> main = new TupleTag<>(); + TupleTag numProcessCalls = new TupleTag<>(); + PCollectionTuple res = + mainInput.apply( + ParDo.of(new SDFWithMultipleOutputsPerBlockAndSideInput(sideInput, numProcessCalls)) + .withSideInputs(sideInput) + .withOutputTags(main, TupleTagList.of(numProcessCalls))); + PCollection>> grouped = + res.get(main).apply(GroupByKey.create()); + + PAssert.that(grouped.apply(Keys.create())) + .containsInAnyOrder("a:0", "a:1", "b:2", "b:3"); + PAssert.that(grouped) + .satisfies( + new SerializableFunction>>, Void>() { + @Override + public Void apply(Iterable>> input) { + List expected = new ArrayList<>(); + for (int i = 0; i < SDFWithMultipleOutputsPerBlockAndSideInput.MAX_INDEX; ++i) { + expected.add(i); + } + for (KV> kv : input) { + assertEquals(expected, Ordering.natural().sortedCopy(kv.getValue())); + } + return null; + } + }); + + // Verify that more than 1 process() call was involved, i.e. 
that there was checkpointing. + PAssert.thatSingleton( + res.get(numProcessCalls) + .setCoder(VarIntCoder.of()) + .apply(Sum.integersGlobally().withoutDefaults())) + // This should hold in all windows, but verifying a particular window is sufficient. + .inOnlyPane(new IntervalWindow(new Instant(0), new Instant(1))) + .satisfies( + new SerializableFunction() { + @Override + public Void apply(Integer input) { + assertThat(input, greaterThan(1)); + return null; + } + }); + p.run(); - // TODO: also add test coverage when the SDF checkpoints - the resumed call should also - // properly access side inputs. // TODO: also test coverage when some of the windows of the side input are not ready. } From 4519681ec3d2fb723a514128d7c9c531c8de9dbf Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Thu, 15 Jun 2017 15:27:18 -0700 Subject: [PATCH 050/200] Populate PBegin input when decoding from Runner API --- sdks/python/apache_beam/pipeline.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index ab77956a0c1a2..d84a2b7b59cce 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -515,7 +515,18 @@ def from_runner_api(proto, runner, options): p.applied_labels = set([ t.unique_name for t in proto.components.transforms.values()]) for id in proto.components.pcollections: - context.pcollections.get_by_id(id).pipeline = p + pcollection = context.pcollections.get_by_id(id) + pcollection.pipeline = p + + # Inject PBegin input where necessary. + from apache_beam.io.iobase import Read + from apache_beam.transforms.core import Create + has_pbegin = [Read, Create] + for id in proto.components.transforms: + transform = context.transforms.get_by_id(id) + if not transform.inputs and transform.transform.__class__ in has_pbegin: + transform.inputs = (pvalue.PBegin(p),) + return p From d7715b78a1f57e6088aabdbc6979cb5809269a97 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Wed, 31 May 2017 08:54:16 -0700 Subject: [PATCH 051/200] [BEAM-1348] Model the Fn State Api as per https://s.apache.org/beam-fn-state-api-and-bundle-processing --- .../fn-api/src/main/proto/beam_fn_api.proto | 97 ++++++++++--------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto index 95fe0424f3ae1..8162bc50598b6 100644 --- a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto +++ b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto @@ -314,9 +314,9 @@ message ProcessBundleRequest { // instantiated and executed by the SDK harness. string process_bundle_descriptor_reference = 1; - // (Optional) A list of cache tokens that can be used by an SDK to cache - // data looked up using the State API across multiple bundles. - repeated CacheToken cache_tokens = 2; + // (Optional) A list of cache tokens that can be used by an SDK to reuse + // cached data returned by the State API across multiple bundles. + repeated bytes cache_tokens = 2; } // Stable @@ -539,6 +539,10 @@ message StateResponse { // failed. string error = 2; + // (Optional) If this is specified, then the result of this state request + // can be cached using the supplied token. + bytes cache_token = 3; + // A corresponding response matching the request will be populated. oneof response { // A response to getting state. 
@@ -564,49 +568,44 @@ service BeamFnState { ) {} } -message CacheToken { - // (Required) Represents the function spec and tag associated with this state - // key. - // - // By combining the function_spec_reference with the tag representing: - // * the input, we refer to the iterable portion of a large GBK - // * the side input, we refer to the side input - // * the user state, we refer to user state - Target target = 1; - - // (Required) An opaque identifier. - bytes token = 2; -} - message StateKey { - // (Required) Represents the function spec and tag associated with this state - // key. - // - // By combining the function_spec_reference with the tag representing: - // * the input, we refer to fetching the iterable portion of a large GBK - // * the side input, we refer to fetching the side input - // * the user state, we refer to fetching user state - Target target = 1; - - // (Required) The bytes of the window which this state request is for encoded - // in the nested context. - bytes window = 2; + message Runner { + // (Required) Opaque information supplied by the runner. Used to support + // remote references. + bytes key = 1; + } - // (Required) The user key encoded in the nested context. - bytes key = 3; -} + message MultimapSideInput { + // (Required) The id of the PTransform containing a side input. + string ptransform_id = 1; + // (Required) The id of the side input. + string side_input_id = 2; + // (Required) The window (after mapping the currently executing elements + // window into the side input windows domain) encoded in a nested context. + bytes window = 3; + // (Required) The key encoded in a nested context. + bytes key = 4; + } -// A logical byte stream which can be continued using the state API. -message ContinuableStream { - // (Optional) If specified, represents a token which can be used with the - // state API to get the next chunk of this logical byte stream. The end of - // the logical byte stream is signalled by this field being unset. - bytes continuation_token = 1; + message BagUserState { + // (Required) The id of the PTransform containing user state. + string ptransform_id = 1; + // (Required) The id of the user state. + string user_state_id = 2; + // (Required) The window encoded in a nested context. + bytes window = 3; + // (Required) The key of the currently executing element encoded in a + // nested context. + bytes key = 4; + } - // Represents a part of a logical byte stream. Elements within - // the logical byte stream are encoded in the nested context and - // concatenated together. - bytes data = 2; + // (Required) One of the following state keys must be set. + oneof type { + Runner runner = 1; + MultimapSideInput multimap_side_input = 2; + BagUserState bag_user_state = 3; + // TODO: represent a state key for user map state + } } // A request to get state. @@ -619,10 +618,18 @@ message StateGetRequest { bytes continuation_token = 1; } -// A response to get state. +// A response to get state representing a logical byte stream which can be +// continued using the state API. message StateGetResponse { - // (Required) The response containing a continuable logical byte stream. - ContinuableStream stream = 1; + // (Optional) If specified, represents a token which can be used with the + // state API to get the next chunk of this logical byte stream. The end of + // the logical byte stream is signalled by this field being unset. + bytes continuation_token = 1; + + // Represents a part of a logical byte stream. 
Elements within + // the logical byte stream are encoded in the nested context and + // concatenated together. + bytes data = 2; } // A request to append state. From bcd439640a635e635cc8686a4e4fcaa94800cb37 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Mon, 12 Jun 2017 22:32:39 -0700 Subject: [PATCH 052/200] Python streaming Create override as a composite of Impulse and a DoFn --- .../runners/dataflow/dataflow_runner.py | 34 +++++++++ .../runners/dataflow/dataflow_runner_test.py | 18 +++++ .../dataflow/native_io/streaming_create.py | 72 +++++++++++++++++++ .../runners/dataflow/ptransform_overrides.py | 52 ++++++++++++++ sdks/python/apache_beam/transforms/core.py | 11 +-- 5 files changed, 183 insertions(+), 4 deletions(-) create mode 100644 sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py create mode 100644 sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index cc9274ec40c78..ce46ea9a23f7a 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -39,6 +39,7 @@ from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api from apache_beam.runners.dataflow.internal.names import PropertyNames from apache_beam.runners.dataflow.internal.names import TransformNames +from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride from apache_beam.runners.runner import PValueCache from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner @@ -69,6 +70,16 @@ class DataflowRunner(PipelineRunner): BATCH_ENVIRONMENT_MAJOR_VERSION = '6' STREAMING_ENVIRONMENT_MAJOR_VERSION = '1' + # A list of PTransformOverride objects to be applied before running a pipeline + # using DataflowRunner. + # Currently this only works for overrides where the input and output types do + # not change. + # For internal SDK use only. This should not be updated by Beam pipeline + # authors. + _PTRANSFORM_OVERRIDES = [ + CreatePTransformOverride(), + ] + def __init__(self, cache=None): # Cache of CloudWorkflowStep protos generated while the runner # "executes" a pipeline. @@ -229,6 +240,9 @@ def run(self, pipeline): 'Google Cloud Dataflow runner not available, ' 'please install apache_beam[gcp]') + # Performing configured PTransform overrides. 
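+    # (replace_all matches each applied transform against these overrides and swaps in the
+    #  replacement transform; here a streaming Create becomes Impulse followed by a decoding DoFn.)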
+ pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES) + # Add setup_options for all the BeamPlugin imports setup_options = pipeline._options.view_as(SetupOptions) plugins = BeamPlugin.get_all_plugin_paths() @@ -370,6 +384,26 @@ def _add_singleton_step(self, label, full_label, tag, input_step): PropertyNames.OUTPUT_NAME: PropertyNames.OUT}]) return step + def run_Impulse(self, transform_node): + standard_options = ( + transform_node.outputs[None].pipeline._options.view_as(StandardOptions)) + if standard_options.streaming: + step = self._add_step( + TransformNames.READ, transform_node.full_label, transform_node) + step.add_property(PropertyNames.FORMAT, 'pubsub') + step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/') + + step.encoding = self._get_encoded_output_coder(transform_node) + step.add_property( + PropertyNames.OUTPUT_INFO, + [{PropertyNames.USER_NAME: ( + '%s.%s' % ( + transform_node.full_label, PropertyNames.OUT)), + PropertyNames.ENCODING: step.encoding, + PropertyNames.OUTPUT_NAME: PropertyNames.OUT}]) + else: + ValueError('Impulse source for batch pipelines has not been defined.') + def run_Flatten(self, transform_node): step = self._add_step(TransformNames.FLATTEN, transform_node.full_label, transform_node) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index 74fd01df7bc23..819d4713c11cd 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -111,6 +111,24 @@ def test_remote_runner_translation(self): remote_runner.job = apiclient.Job(p._options) super(DataflowRunner, remote_runner).run(p) + def test_streaming_create_translation(self): + remote_runner = DataflowRunner() + self.default_properties.append("--streaming") + p = Pipeline(remote_runner, PipelineOptions(self.default_properties)) + p | ptransform.Create([1]) # pylint: disable=expression-not-assigned + remote_runner.job = apiclient.Job(p._options) + # Performing configured PTransform overrides here. + p.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES) + super(DataflowRunner, remote_runner).run(p) + job_dict = json.loads(str(remote_runner.job)) + self.assertEqual(len(job_dict[u'steps']), 2) + + self.assertEqual(job_dict[u'steps'][0][u'kind'], u'ParallelRead') + self.assertEqual( + job_dict[u'steps'][0][u'properties'][u'pubsub_subscription'], + '_starting_signal/') + self.assertEqual(job_dict[u'steps'][1][u'kind'], u'ParallelDo') + def test_remote_runner_display_data(self): remote_runner = DataflowRunner() p = Pipeline(remote_runner, diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py b/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py new file mode 100644 index 0000000000000..8c6c8d6d52998 --- /dev/null +++ b/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Create transform for streaming.""" + +from apache_beam import pvalue +from apache_beam import DoFn +from apache_beam import ParDo +from apache_beam import PTransform +from apache_beam import Windowing +from apache_beam.transforms.window import GlobalWindows + + +class StreamingCreate(PTransform): + """A specialized implementation for ``Create`` transform in streaming mode. + + Note: There is no unbounded source API in python to wrap the Create source, + so we map this to composite of Impulse primitive and an SDF. + """ + + def __init__(self, values, coder): + self.coder = coder + self.encoded_values = map(coder.encode, values) + + class DecodeAndEmitDoFn(DoFn): + """A DoFn which stores encoded versions of elements. + + It also stores a Coder to decode and emit those elements. + TODO: BEAM-2422 - Make this a SplittableDoFn. + """ + + def __init__(self, encoded_values, coder): + self.encoded_values = encoded_values + self.coder = coder + + def process(self, unused_element): + for encoded_value in self.encoded_values: + yield self.coder.decode(encoded_value) + + class Impulse(PTransform): + """The Dataflow specific override for the impulse primitive.""" + + def expand(self, pbegin): + assert isinstance(pbegin, pvalue.PBegin), ( + 'Input to Impulse transform must be a PBegin but found %s' % pbegin) + return pvalue.PCollection(pbegin.pipeline) + + def get_windowing(self, inputs): + return Windowing(GlobalWindows()) + + def infer_output_type(self, unused_input_type): + return bytes + + def expand(self, pbegin): + return (pbegin + | 'Impulse' >> self.Impulse() + | 'Decode Values' >> ParDo( + self.DecodeAndEmitDoFn(self.encoded_values, self.coder))) diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py new file mode 100644 index 0000000000000..680a4b7de5c20 --- /dev/null +++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Ptransform overrides for DataflowRunner.""" + +from apache_beam.coders import typecoders +from apache_beam.pipeline import PTransformOverride + + +class CreatePTransformOverride(PTransformOverride): + """A ``PTransformOverride`` for ``Create`` in streaming mode.""" + + def get_matcher(self): + return self.is_streaming_create + + @staticmethod + def is_streaming_create(applied_ptransform): + # Imported here to avoid circular dependencies. + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import Create + from apache_beam.options.pipeline_options import StandardOptions + + if isinstance(applied_ptransform.transform, Create): + standard_options = (applied_ptransform + .outputs[None] + .pipeline._options + .view_as(StandardOptions)) + return standard_options.streaming + else: + return False + + def get_replacement_transform(self, ptransform): + # Imported here to avoid circular dependencies. + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam.runners.dataflow.native_io.streaming_create import \ + StreamingCreate + coder = typecoders.registry.get_coder(ptransform.get_output_type()) + return StreamingCreate(ptransform.value, coder) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index c30136de2a439..801821909cf9a 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1444,15 +1444,18 @@ def infer_output_type(self, unused_input_type): return Any return Union[[trivial_inference.instance_to_type(v) for v in self.value]] + def get_output_type(self): + return (self.get_type_hints().simple_output_type(self.label) or + self.infer_output_type(None)) + def expand(self, pbegin): from apache_beam.io import iobase assert isinstance(pbegin, pvalue.PBegin) self.pipeline = pbegin.pipeline - ouput_type = (self.get_type_hints().simple_output_type(self.label) or - self.infer_output_type(None)) - coder = typecoders.registry.get_coder(ouput_type) + coder = typecoders.registry.get_coder(self.get_output_type()) source = self._create_source_from_iterable(self.value, coder) - return pbegin.pipeline | iobase.Read(source).with_output_types(ouput_type) + return (pbegin.pipeline + | iobase.Read(source).with_output_types(self.get_output_type())) def get_windowing(self, unused_inputs): return Windowing(GlobalWindows()) From cf654a0bcd876310311f48deb64cd49d7df2893c Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Fri, 16 Jun 2017 13:07:48 -0700 Subject: [PATCH 053/200] A few cleanups in CombineTest Better error messages and IntelliJ warning cleanups. 
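Replaces checkArgument-based validation inside the test CombineFns with JUnit assertEquals so
failures report expected and actual values, removes the unused mocked ProcessContext, and passes
side-input views directly to withSideInputs instead of wrapping them in Arrays.asList.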
--- .../beam/sdk/transforms/CombineTest.java | 125 ++++++++---------- 1 file changed, 53 insertions(+), 72 deletions(-) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java index c4ba62d148fae..e2469ab3ba6f1 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.transforms; -import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.sdk.TestUtils.checkCombineFn; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; @@ -45,7 +45,6 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.BigEndianLongCoder; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.DoubleCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -85,7 +84,6 @@ import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mock; /** * Tests for Combine transforms. @@ -97,8 +95,6 @@ public class CombineTest implements Serializable { static final List> EMPTY_TABLE = Collections.emptyList(); - @Mock private DoFn.ProcessContext processContext; - @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @@ -142,12 +138,12 @@ private void runTestSimpleCombineWithContext(List> table, PCollection> combinePerKey = perKeyInput.apply( Combine.perKey(new TestCombineFnWithContext(globallySumView)) - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PCollection combineGlobally = globallyInput .apply(Combine.globally(new TestCombineFnWithContext(globallySumView)) .withoutDefaults() - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PAssert.that(sum).containsInAnyOrder(globalSum); PAssert.that(combinePerKey).containsInAnyOrder(perKeyCombines); @@ -280,11 +276,9 @@ public void testFixedWindowsCombine() { .apply(Combine.perKey(new TestCombineFn())); PAssert.that(sum).containsInAnyOrder(2, 5, 13); - PAssert.that(sumPerKey).containsInAnyOrder( - KV.of("a", "11"), - KV.of("a", "4"), - KV.of("b", "1"), - KV.of("b", "13")); + PAssert.that(sumPerKey) + .containsInAnyOrder( + Arrays.asList(KV.of("a", "11"), KV.of("a", "4"), KV.of("b", "1"), KV.of("b", "13"))); pipeline.run(); } @@ -313,19 +307,18 @@ public void testFixedWindowsCombineWithContext() { PCollection> combinePerKeyWithContext = perKeyInput.apply( Combine.perKey(new TestCombineFnWithContext(globallySumView)) - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PCollection combineGloballyWithContext = globallyInput .apply(Combine.globally(new TestCombineFnWithContext(globallySumView)) .withoutDefaults() - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PAssert.that(sum).containsInAnyOrder(2, 5, 13); - PAssert.that(combinePerKeyWithContext).containsInAnyOrder( - KV.of("a", "2:11"), - KV.of("a", "5:4"), - KV.of("b", "5:1"), - KV.of("b", "13:13")); + PAssert.that(combinePerKeyWithContext) + .containsInAnyOrder( + 
Arrays.asList( + KV.of("a", "2:11"), KV.of("a", "5:4"), KV.of("b", "5:1"), KV.of("b", "13:13"))); PAssert.that(combineGloballyWithContext).containsInAnyOrder("2:11", "5:14", "13:13"); pipeline.run(); } @@ -355,23 +348,25 @@ public void testSlidingWindowsCombineWithContext() { PCollection> combinePerKeyWithContext = perKeyInput.apply( Combine.perKey(new TestCombineFnWithContext(globallySumView)) - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PCollection combineGloballyWithContext = globallyInput .apply(Combine.globally(new TestCombineFnWithContext(globallySumView)) .withoutDefaults() - .withSideInputs(Arrays.asList(globallySumView))); + .withSideInputs(globallySumView)); PAssert.that(sum).containsInAnyOrder(1, 2, 1, 4, 5, 14, 13); - PAssert.that(combinePerKeyWithContext).containsInAnyOrder( - KV.of("a", "1:1"), - KV.of("a", "2:11"), - KV.of("a", "1:1"), - KV.of("a", "4:4"), - KV.of("a", "5:4"), - KV.of("b", "5:1"), - KV.of("b", "14:113"), - KV.of("b", "13:13")); + PAssert.that(combinePerKeyWithContext) + .containsInAnyOrder( + Arrays.asList( + KV.of("a", "1:1"), + KV.of("a", "2:11"), + KV.of("a", "1:1"), + KV.of("a", "4:4"), + KV.of("a", "5:4"), + KV.of("b", "5:1"), + KV.of("b", "14:113"), + KV.of("b", "13:13"))); PAssert.that(combineGloballyWithContext).containsInAnyOrder( "1:1", "2:11", "1:1", "4:4", "5:14", "14:113", "13:13"); pipeline.run(); @@ -433,10 +428,8 @@ public void testSessionsCombine() { .apply(Combine.perKey(new TestCombineFn())); PAssert.that(sum).containsInAnyOrder(7, 13); - PAssert.that(sumPerKey).containsInAnyOrder( - KV.of("a", "114"), - KV.of("b", "1"), - KV.of("b", "13")); + PAssert.that(sumPerKey) + .containsInAnyOrder(Arrays.asList(KV.of("a", "114"), KV.of("b", "1"), KV.of("b", "13"))); pipeline.run(); } @@ -471,7 +464,7 @@ public void testSessionsCombineWithContext() { .apply( Combine.perKey( new TestCombineFnWithContext(globallyFixedWindowsView)) - .withSideInputs(Arrays.asList(globallyFixedWindowsView))); + .withSideInputs(globallyFixedWindowsView)); PCollection sessionsCombineGlobally = globallyInput @@ -481,13 +474,12 @@ public void testSessionsCombineWithContext() { .apply( Combine.globally(new TestCombineFnWithContext(globallyFixedWindowsView)) .withoutDefaults() - .withSideInputs(Arrays.asList(globallyFixedWindowsView))); + .withSideInputs(globallyFixedWindowsView)); PAssert.that(fixedWindowsSum).containsInAnyOrder(2, 4, 1, 13); - PAssert.that(sessionsCombinePerKey).containsInAnyOrder( - KV.of("a", "1:114"), - KV.of("b", "1:1"), - KV.of("b", "0:13")); + PAssert.that(sessionsCombinePerKey) + .containsInAnyOrder( + Arrays.asList(KV.of("a", "1:114"), KV.of("b", "1:1"), KV.of("b", "0:13"))); PAssert.that(sessionsCombineGlobally).containsInAnyOrder("1:1114", "0:13"); pipeline.run(); } @@ -716,7 +708,7 @@ public void testWindowedCombineGloballyAsSingletonView() { pipeline .apply( "CreateMainInput", - Create.timestamped(nonEmptyElement, emptyElement).withCoder(VoidCoder.of())) + Create.timestamped(nonEmptyElement, emptyElement).withCoder(VoidCoder.of())) .apply("WindowMainInput", Window.into(windowFn)) .apply( "OutputSideInput", @@ -941,15 +933,13 @@ public Coder getAccumulatorCoder( */ private class CountSumCoder extends AtomicCoder { @Override - public void encode(CountSum value, OutputStream outStream) - throws CoderException, IOException { + public void encode(CountSum value, OutputStream outStream) throws IOException { LONG_CODER.encode(value.count, outStream); DOUBLE_CODER.encode(value.sum, outStream); } @Override - 
public CountSum decode(InputStream inStream) - throws CoderException, IOException { + public CountSum decode(InputStream inStream) throws IOException { long count = LONG_CODER.decode(inStream); double sum = DOUBLE_CODER.decode(inStream); return new CountSum(count, sum); @@ -992,28 +982,15 @@ public Accumulator(String seed, String value) { public static Coder getCoder() { return new AtomicCoder() { @Override - public void encode(Accumulator accumulator, OutputStream outStream) - throws CoderException, IOException { - encode(accumulator, outStream, Coder.Context.NESTED); + public void encode(Accumulator accumulator, OutputStream outStream) throws IOException { + StringUtf8Coder.of().encode(accumulator.seed, outStream); + StringUtf8Coder.of().encode(accumulator.value, outStream); } @Override - public void encode(Accumulator accumulator, OutputStream outStream, Coder.Context context) - throws CoderException, IOException { - StringUtf8Coder.of().encode(accumulator.seed, outStream, context.nested()); - StringUtf8Coder.of().encode(accumulator.value, outStream, context); - } - - @Override - public Accumulator decode(InputStream inStream) throws CoderException, IOException { - return decode(inStream, Coder.Context.NESTED); - } - - @Override - public Accumulator decode(InputStream inStream, Coder.Context context) - throws CoderException, IOException { - String seed = StringUtf8Coder.of().decode(inStream, context.nested()); - String value = StringUtf8Coder.of().decode(inStream, context); + public Accumulator decode(InputStream inStream) throws IOException { + String seed = StringUtf8Coder.of().decode(inStream); + String value = StringUtf8Coder.of().decode(inStream); return new Accumulator(seed, value); } }; @@ -1042,18 +1019,22 @@ public Accumulator addInput(Accumulator accumulator, Integer value) { @Override public Accumulator mergeAccumulators(Iterable accumulators) { - String seed = null; - String all = ""; + Accumulator seedAccumulator = null; + StringBuilder all = new StringBuilder(); for (Accumulator accumulator : accumulators) { - if (seed == null) { - seed = accumulator.seed; + if (seedAccumulator == null) { + seedAccumulator = accumulator; } else { - checkArgument(seed.equals(accumulator.seed), "Different seed values in accumulator"); + assertEquals( + String.format( + "Different seed values in accumulator: %s vs. 
%s", seedAccumulator, accumulator), + seedAccumulator.seed, + accumulator.seed); } - all += accumulator.value; + all.append(accumulator.value); accumulator.value = "cleared in mergeAccumulators"; } - return new Accumulator(seed, all); + return new Accumulator(checkNotNull(seedAccumulator).seed, all.toString()); } @Override @@ -1161,7 +1142,7 @@ public void addInput(Integer element) { @Override public void mergeAccumulator(Counter accumulator) { checkState(outputs == 0); - checkArgument(accumulator.outputs == 0); + assertEquals(0, accumulator.outputs); merges += accumulator.merges + 1; inputs += accumulator.inputs; From c2683e876bea541684977ded9e179dd0e1a8ccdf Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Fri, 16 Jun 2017 17:18:57 -0700 Subject: [PATCH 054/200] Add dry run option to DataflowRunner --- sdks/python/apache_beam/options/pipeline_options.py | 5 +++++ .../apache_beam/runners/dataflow/dataflow_runner.py | 7 ++++++- .../apache_beam/runners/dataflow/dataflow_runner_test.py | 9 ++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 8644e51b2dbe0..dab8ff204d3a4 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -605,6 +605,11 @@ def _add_argparse_args(cls, parser): help=('Verify state/output of e2e test pipeline. This is pickled ' 'version of the matcher which should extends ' 'hamcrest.core.base_matcher.BaseMatcher.')) + parser.add_argument( + '--dry_run', + default=False, + help=('Used in unit testing runners without submitting the ' + 'actual job.')) def validate(self, validator): errors = [] diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index ce46ea9a23f7a..9395f1688056e 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -48,6 +48,7 @@ from apache_beam.typehints import typehints from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import TestOptions from apache_beam.utils.plugin import BeamPlugin @@ -228,7 +229,6 @@ def visit_transform(self, transform_node): return FlattenInputVisitor() - # TODO(mariagh): Make this method take pipepline_options def run(self, pipeline): """Remotely executes entire pipeline or parts reachable from node.""" # Import here to avoid adding the dependency for local running scenarios. @@ -263,6 +263,11 @@ def run(self, pipeline): # The superclass's run will trigger a traversal of all reachable nodes. super(DataflowRunner, self).run(pipeline) + test_options = pipeline._options.view_as(TestOptions) + # If it is a dry run, return without submitting the job. 
+ if test_options.dry_run: + return None + standard_options = pipeline._options.view_as(StandardOptions) if standard_options.streaming: job_version = DataflowRunner.STREAMING_ENVIRONMENT_MAJOR_VERSION diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index 819d4713c11cd..6cc5814a5fbcf 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -59,7 +59,8 @@ class DataflowRunnerTest(unittest.TestCase): '--project=test-project', '--staging_location=ignored', '--temp_location=/dev/null', - '--no_auth=True'] + '--no_auth=True', + '--dry_run=True'] @mock.patch('time.sleep', return_value=None) def test_wait_until_finish(self, patched_time_sleep): @@ -108,8 +109,7 @@ def test_remote_runner_translation(self): (p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) | ptransform.GroupByKey()) - remote_runner.job = apiclient.Job(p._options) - super(DataflowRunner, remote_runner).run(p) + p.run() def test_streaming_create_translation(self): remote_runner = DataflowRunner() @@ -160,8 +160,7 @@ def process(self): (p | ptransform.Create([1, 2, 3, 4, 5]) | 'Do' >> SpecialParDo(SpecialDoFn(), now)) - remote_runner.job = apiclient.Job(p._options) - super(DataflowRunner, remote_runner).run(p) + p.run() job_dict = json.loads(str(remote_runner.job)) steps = [step for step in job_dict['steps'] From 87be64e9817da5e5c86a243471021268d6281b33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Fri, 12 May 2017 15:21:49 +0200 Subject: [PATCH 055/200] [BEAM-975] Improve default connection options, javadoc and style in MongoDbIO --- .../apache/beam/sdk/io/mongodb/MongoDbIO.java | 315 ++++++++++++++---- .../beam/sdk/io/mongodb/MongoDbIOTest.java | 37 ++ 2 files changed, 283 insertions(+), 69 deletions(-) diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java index 620df74f24646..04d9975a6760a 100644 --- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java +++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java @@ -18,12 +18,13 @@ package org.apache.beam.sdk.io.mongodb; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.mongodb.BasicDBObject; import com.mongodb.MongoClient; +import com.mongodb.MongoClientOptions; import com.mongodb.MongoClientURI; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoCursor; @@ -100,12 +101,20 @@ public class MongoDbIO { /** Read data from MongoDB. */ public static Read read() { - return new AutoValue_MongoDbIO_Read.Builder().setNumSplits(0).build(); + return new AutoValue_MongoDbIO_Read.Builder() + .setKeepAlive(true) + .setMaxConnectionIdleTime(60000) + .setNumSplits(0) + .build(); } /** Write data to MongoDB. 
*/ public static Write write() { - return new AutoValue_MongoDbIO_Write.Builder().setBatchSize(1024L).build(); + return new AutoValue_MongoDbIO_Write.Builder() + .setKeepAlive(true) + .setMaxConnectionIdleTime(60000) + .setBatchSize(1024L) + .build(); } private MongoDbIO() { @@ -117,16 +126,20 @@ private MongoDbIO() { @AutoValue public abstract static class Read extends PTransform> { @Nullable abstract String uri(); + abstract boolean keepAlive(); + abstract int maxConnectionIdleTime(); @Nullable abstract String database(); @Nullable abstract String collection(); @Nullable abstract String filter(); abstract int numSplits(); - abstract Builder toBuilder(); + abstract Builder builder(); @AutoValue.Builder abstract static class Builder { abstract Builder setUri(String uri); + abstract Builder setKeepAlive(boolean keepAlive); + abstract Builder setMaxConnectionIdleTime(int maxConnectionIdleTime); abstract Builder setDatabase(String database); abstract Builder setCollection(String collection); abstract Builder setFilter(String filter); @@ -135,31 +148,94 @@ abstract static class Builder { } /** - * Example documentation for withUri. + * Define the location of the MongoDB instances using an URI. The URI describes the hosts to + * be used and some options. + * + *

+     * <p>The format of the URI is:
+     *
+     * <pre>{@code
+     * mongodb://[username:password@]host1[:port1]...[,hostN[:portN]]][/[database][?options]]
+     * }</pre>
+     *
+     * <p>Where:
+     * <ul>
+     *   <li>{@code mongodb://} is a required prefix to identify that this is a string in the
+     *   standard connection format.</li>
+     *   <li>{@code username:password@} are optional. If given, the driver will attempt to
+     *   login to a database after connecting to a database server. For some authentication
+     *   mechanisms, only the username is specified and the password is not, in which case
+     *   the ":" after the username is left off as well.</li>
+     *   <li>{@code host1} is the only required part of the URI. It identifies a server
+     *   address to connect to.</li>
+     *   <li>{@code :portX} is optional and defaults to {@code :27017} if not provided.</li>
+     *   <li>{@code /database} is the name of the database to login to and thus is only
+     *   relevant if the {@code username:password@} syntax is used. If not specified, the
+     *   "admin" database will be used by default. It has to be equivalent to the database
+     *   you specify with {@link Read#withDatabase(String)}.</li>
+     *   <li>{@code ?options} are connection options. Note that if {@code database} is absent
+     *   there is still a {@code /} required between the last {@code host} and the {@code ?}
+     *   introducing the options. Options are name=value pairs and the pairs are separated by
+     *   "{@code &}". The {@code KeepAlive} connection option can't be passed via the URI,
+     *   instead you have to use {@link Read#withKeepAlive(boolean)}. Same for the
+     *   {@code MaxConnectionIdleTime} connection option via
+     *   {@link Read#withMaxConnectionIdleTime(int)}.</li>
+     * </ul>
    */ public Read withUri(String uri) { - checkNotNull(uri); - return toBuilder().setUri(uri).build(); + checkArgument(uri != null, "MongoDbIO.read().withUri(uri) called with null uri"); + return builder().setUri(uri).build(); + } + + /** + * Sets whether socket keep alive is enabled. + */ + public Read withKeepAlive(boolean keepAlive) { + return builder().setKeepAlive(keepAlive).build(); + } + + /** + * Sets the maximum idle time for a pooled connection. + */ + public Read withMaxConnectionIdleTime(int maxConnectionIdleTime) { + return builder().setMaxConnectionIdleTime(maxConnectionIdleTime).build(); } + /** + * Sets the database to use. + */ public Read withDatabase(String database) { - checkNotNull(database); - return toBuilder().setDatabase(database).build(); + checkArgument(database != null, "MongoDbIO.read().withDatabase(database) called with null" + + " database"); + return builder().setDatabase(database).build(); } + /** + * Sets the collection to consider in the database. + */ public Read withCollection(String collection) { - checkNotNull(collection); - return toBuilder().setCollection(collection).build(); + checkArgument(collection != null, "MongoDbIO.read().withCollection(collection) called " + + "with null collection"); + return builder().setCollection(collection).build(); } + /** + * Sets a filter on the documents in a collection. + */ public Read withFilter(String filter) { - checkNotNull(filter); - return toBuilder().setFilter(filter).build(); + checkArgument(filter != null, "MongoDbIO.read().withFilter(filter) called with null " + + "filter"); + return builder().setFilter(filter).build(); } + /** + * Sets the user defined number of splits. + */ public Read withNumSplits(int numSplits) { - checkArgument(numSplits >= 0); - return toBuilder().setNumSplits(numSplits).build(); + checkArgument(numSplits >= 0, "MongoDbIO.read().withNumSplits(numSplits) called with " + + "invalid number. 
The number of splits has to be a positive value (currently %d)", + numSplits); + return builder().setNumSplits(numSplits).build(); } @Override @@ -169,15 +245,19 @@ public PCollection expand(PBegin input) { @Override public void validate(PipelineOptions options) { - checkNotNull(uri(), "uri"); - checkNotNull(database(), "database"); - checkNotNull(collection(), "collection"); + checkState(uri() != null, "MongoDbIO.read() requires an URI to be set via withUri(uri)"); + checkState(database() != null, "MongoDbIO.read() requires a database to be set via " + + "withDatabase(database)"); + checkState(collection() != null, "MongoDbIO.read() requires a collection to be set via " + + "withCollection(collection)"); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.add(DisplayData.item("uri", uri())); + builder.add(DisplayData.item("keepAlive", keepAlive())); + builder.add(DisplayData.item("maxConnectionIdleTime", maxConnectionIdleTime())); builder.add(DisplayData.item("database", database())); builder.add(DisplayData.item("collection", collection())); builder.addIfNotNull(DisplayData.item("filter", filter())); @@ -218,61 +298,71 @@ public BoundedReader createReader(PipelineOptions options) { @Override public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) { - MongoClient mongoClient = new MongoClient(new MongoClientURI(spec.uri())); - MongoDatabase mongoDatabase = mongoClient.getDatabase(spec.database()); + try (MongoClient mongoClient = new MongoClient(new MongoClientURI(spec.uri()))) { + return getEstimatedSizeBytes(mongoClient, spec.database(), spec.collection()); + } + } + + private long getEstimatedSizeBytes(MongoClient mongoClient, + String database, + String collection) { + MongoDatabase mongoDatabase = mongoClient.getDatabase(database); // get the Mongo collStats object // it gives the size for the entire collection BasicDBObject stat = new BasicDBObject(); - stat.append("collStats", spec.collection()); + stat.append("collStats", collection); Document stats = mongoDatabase.runCommand(stat); + return stats.get("size", Number.class).longValue(); } @Override public List> split(long desiredBundleSizeBytes, PipelineOptions options) { - MongoClient mongoClient = new MongoClient(new MongoClientURI(spec.uri())); - MongoDatabase mongoDatabase = mongoClient.getDatabase(spec.database()); - - List splitKeys; - if (spec.numSplits() > 0) { - // the user defines his desired number of splits - // calculate the batch size - long estimatedSizeBytes = getEstimatedSizeBytes(options); - desiredBundleSizeBytes = estimatedSizeBytes / spec.numSplits(); - } + try (MongoClient mongoClient = new MongoClient(new MongoClientURI(spec.uri()))) { + MongoDatabase mongoDatabase = mongoClient.getDatabase(spec.database()); + + List splitKeys; + if (spec.numSplits() > 0) { + // the user defines his desired number of splits + // calculate the batch size + long estimatedSizeBytes = getEstimatedSizeBytes(mongoClient, + spec.database(), spec.collection()); + desiredBundleSizeBytes = estimatedSizeBytes / spec.numSplits(); + } - // the desired batch size is small, using default chunk size of 1MB - if (desiredBundleSizeBytes < 1024 * 1024) { - desiredBundleSizeBytes = 1 * 1024 * 1024; - } + // the desired batch size is small, using default chunk size of 1MB + if (desiredBundleSizeBytes < 1024 * 1024) { + desiredBundleSizeBytes = 1 * 1024 * 1024; + } - // now we have the batch size (provided by user or provided by the runner) - // we use Mongo splitVector 
command to get the split keys - BasicDBObject splitVectorCommand = new BasicDBObject(); - splitVectorCommand.append("splitVector", spec.database() + "." + spec.collection()); - splitVectorCommand.append("keyPattern", new BasicDBObject().append("_id", 1)); - splitVectorCommand.append("force", false); - // maxChunkSize is the Mongo partition size in MB - LOG.debug("Splitting in chunk of {} MB", desiredBundleSizeBytes / 1024 / 1024); - splitVectorCommand.append("maxChunkSize", desiredBundleSizeBytes / 1024 / 1024); - Document splitVectorCommandResult = mongoDatabase.runCommand(splitVectorCommand); - splitKeys = (List) splitVectorCommandResult.get("splitKeys"); - - List> sources = new ArrayList<>(); - if (splitKeys.size() < 1) { - LOG.debug("Split keys is low, using an unique source"); - sources.add(this); - return sources; - } + // now we have the batch size (provided by user or provided by the runner) + // we use Mongo splitVector command to get the split keys + BasicDBObject splitVectorCommand = new BasicDBObject(); + splitVectorCommand.append("splitVector", spec.database() + "." + spec.collection()); + splitVectorCommand.append("keyPattern", new BasicDBObject().append("_id", 1)); + splitVectorCommand.append("force", false); + // maxChunkSize is the Mongo partition size in MB + LOG.debug("Splitting in chunk of {} MB", desiredBundleSizeBytes / 1024 / 1024); + splitVectorCommand.append("maxChunkSize", desiredBundleSizeBytes / 1024 / 1024); + Document splitVectorCommandResult = mongoDatabase.runCommand(splitVectorCommand); + splitKeys = (List) splitVectorCommandResult.get("splitKeys"); + + List> sources = new ArrayList<>(); + if (splitKeys.size() < 1) { + LOG.debug("Split keys is low, using an unique source"); + sources.add(this); + return sources; + } - LOG.debug("Number of splits is {}", splitKeys.size()); - for (String shardFilter : splitKeysToFilters(splitKeys, spec.filter())) { - sources.add(new BoundedMongoDbSource(spec.withFilter(shardFilter))); - } + LOG.debug("Number of splits is {}", splitKeys.size()); + for (String shardFilter : splitKeysToFilters(splitKeys, spec.filter())) { + sources.add(new BoundedMongoDbSource(spec.withFilter(shardFilter))); + } - return sources; + return sources; + } } /** @@ -367,7 +457,10 @@ public BoundedMongoDbReader(BoundedMongoDbSource source) { @Override public boolean start() { Read spec = source.spec; - client = new MongoClient(new MongoClientURI(spec.uri())); + MongoClientOptions.Builder optionsBuilder = new MongoClientOptions.Builder(); + optionsBuilder.maxConnectionIdleTime(spec.maxConnectionIdleTime()); + optionsBuilder.socketKeepAlive(spec.keepAlive()); + client = new MongoClient(new MongoClientURI(spec.uri(), optionsBuilder)); MongoDatabase mongoDatabase = client.getDatabase(spec.database()); @@ -426,36 +519,106 @@ public void close() { */ @AutoValue public abstract static class Write extends PTransform, PDone> { + @Nullable abstract String uri(); + abstract boolean keepAlive(); + abstract int maxConnectionIdleTime(); @Nullable abstract String database(); @Nullable abstract String collection(); abstract long batchSize(); - abstract Builder toBuilder(); + abstract Builder builder(); @AutoValue.Builder abstract static class Builder { abstract Builder setUri(String uri); + abstract Builder setKeepAlive(boolean keepAlive); + abstract Builder setMaxConnectionIdleTime(int maxConnectionIdleTime); abstract Builder setDatabase(String database); abstract Builder setCollection(String collection); abstract Builder setBatchSize(long batchSize); abstract Write 
build(); } + /** + * Define the location of the MongoDB instances using an URI. The URI describes the hosts to + * be used and some options. + * + *

+     * <p>The format of the URI is:
+     *
+     * <pre>{@code
+     * mongodb://[username:password@]host1[:port1],...[,hostN[:portN]]][/[database][?options]]
+     * }</pre>
+     *
+     * <p>Where:
+     * <ul>
+     *   <li>{@code mongodb://} is a required prefix to identify that this is a string in the
+     *   standard connection format.</li>
+     *   <li>{@code username:password@} are optional. If given, the driver will attempt to
+     *   login to a database after connecting to a database server. For some authentication
+     *   mechanisms, only the username is specified and the password is not, in which case
+     *   the ":" after the username is left off as well.</li>
+     *   <li>{@code host1} is the only required part of the URI. It identifies a server
+     *   address to connect to.</li>
+     *   <li>{@code :portX} is optional and defaults to {@code :27017} if not provided.</li>
+     *   <li>{@code /database} is the name of the database to login to and thus is only
+     *   relevant if the {@code username:password@} syntax is used. If not specified, the
+     *   "admin" database will be used by default. It has to be equivalent to the database
+     *   you specify with {@link Write#withDatabase(String)}.</li>
+     *   <li>{@code ?options} are connection options. Note that if {@code database} is absent
+     *   there is still a {@code /} required between the last {@code host} and the {@code ?}
+     *   introducing the options. Options are name=value pairs and the pairs are separated by
+     *   "{@code &}". The {@code KeepAlive} connection option can't be passed via the URI, instead
+     *   you have to use {@link Write#withKeepAlive(boolean)}. Same for the
+     *   {@code MaxConnectionIdleTime} connection option via
+     *   {@link Write#withMaxConnectionIdleTime(int)}.</li>
+     * </ul>
    + */ public Write withUri(String uri) { - return toBuilder().setUri(uri).build(); + checkArgument(uri != null, "MongoDbIO.write().withUri(uri) called with null uri"); + return builder().setUri(uri).build(); + } + + /** + * Sets whether socket keep alive is enabled. + */ + public Write withKeepAlive(boolean keepAlive) { + return builder().setKeepAlive(keepAlive).build(); + } + + /** + * Sets the maximum idle time for a pooled connection. + */ + public Write withMaxConnectionIdleTime(int maxConnectionIdleTime) { + return builder().setMaxConnectionIdleTime(maxConnectionIdleTime).build(); } + /** + * Sets the database to use. + */ public Write withDatabase(String database) { - return toBuilder().setDatabase(database).build(); + checkArgument(database != null, "MongoDbIO.write().withDatabase(database) called with " + + "null database"); + return builder().setDatabase(database).build(); } + /** + * Sets the collection where to write data in the database. + */ public Write withCollection(String collection) { - return toBuilder().setCollection(collection).build(); + checkArgument(collection != null, "MongoDbIO.write().withCollection(collection) called " + + "with null collection"); + return builder().setCollection(collection).build(); } + /** + * Define the size of the batch to group write operations. + */ public Write withBatchSize(long batchSize) { - return toBuilder().setBatchSize(batchSize).build(); + checkArgument(batchSize >= 0, "MongoDbIO.write().withBatchSize(batchSize) called with " + + "invalid batch size. Batch size has to be >= 0 (currently %d)", batchSize); + return builder().setBatchSize(batchSize).build(); } @Override @@ -466,10 +629,21 @@ public PDone expand(PCollection input) { @Override public void validate(PipelineOptions options) { - checkNotNull(uri(), "uri"); - checkNotNull(database(), "database"); - checkNotNull(collection(), "collection"); - checkNotNull(batchSize(), "batchSize"); + checkState(uri() != null, "MongoDbIO.write() requires an URI to be set via withUri(uri)"); + checkState(database() != null, "MongoDbIO.write() requires a database to be set via " + + "withDatabase(database)"); + checkState(collection() != null, "MongoDbIO.write() requires a collection to be set via " + + "withCollection(collection)"); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("uri", uri())); + builder.add(DisplayData.item("keepAlive", keepAlive())); + builder.add(DisplayData.item("maxConnectionIdleTime", maxConnectionIdleTime())); + builder.add(DisplayData.item("database", database())); + builder.add(DisplayData.item("collection", collection())); + builder.add(DisplayData.item("batchSize", batchSize())); } private static class WriteFn extends DoFn { @@ -483,7 +657,10 @@ public WriteFn(Write spec) { @Setup public void createMongoClient() throws Exception { - client = new MongoClient(new MongoClientURI(spec.uri())); + MongoClientOptions.Builder builder = new MongoClientOptions.Builder(); + builder.socketKeepAlive(spec.keepAlive()); + builder.maxConnectionIdleTime(spec.maxConnectionIdleTime()); + client = new MongoClient(new MongoClientURI(spec.uri(), builder)); } @StartBundle diff --git a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java index cd26b483cda93..67dbca4af196b 100644 --- a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java +++ 
b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.mongodb; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import com.mongodb.MongoClient; import com.mongodb.client.MongoCollection; @@ -188,6 +189,42 @@ public Void apply(Iterable> input) { pipeline.run(); } + @Test + public void testReadWithCustomConnectionOptions() throws Exception { + MongoDbIO.Read read = MongoDbIO.read() + .withUri("mongodb://localhost:" + port) + .withKeepAlive(false) + .withMaxConnectionIdleTime(10) + .withDatabase(DATABASE) + .withCollection(COLLECTION); + assertFalse(read.keepAlive()); + assertEquals(10, read.maxConnectionIdleTime()); + + PCollection documents = pipeline.apply(read); + + PAssert.thatSingleton(documents.apply("Count All", Count.globally())) + .isEqualTo(1000L); + + PAssert.that(documents + .apply("Map Scientist", MapElements.via(new SimpleFunction>() { + public KV apply(Document input) { + return KV.of(input.getString("scientist"), null); + } + })) + .apply("Count Scientist", Count.perKey()) + ).satisfies(new SerializableFunction>, Void>() { + @Override + public Void apply(Iterable> input) { + for (KV element : input) { + assertEquals(100L, element.getValue().longValue()); + } + return null; + } + }); + + pipeline.run(); + } + @Test public void testReadWithFilter() throws Exception { From 6d27282562911179ea3ff19fd7eae54e8b89425d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Mon, 19 Jun 2017 16:43:28 +0200 Subject: [PATCH 056/200] Make HBaseIO tests faster by only using the core daemons needed by HBase --- .../org/apache/beam/sdk/io/hbase/HBaseIOTest.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java index 4a067895b22c5..005770da86569 100644 --- a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java @@ -47,6 +47,7 @@ import org.apache.hadoop.hbase.HColumnDescriptor; import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.HTableDescriptor; +import org.apache.hadoop.hbase.MiniHBaseCluster; import org.apache.hadoop.hbase.TableName; import org.apache.hadoop.hbase.client.BufferedMutator; import org.apache.hadoop.hbase.client.Connection; @@ -96,7 +97,12 @@ public static void beforeClass() throws Exception { conf.setStrings("hbase.master.hostname", "localhost"); conf.setStrings("hbase.regionserver.hostname", "localhost"); htu = new HBaseTestingUtility(conf); - htu.startMiniCluster(1, 4); + + // We don't use the full htu.startMiniCluster() to avoid starting unneeded HDFS/MR daemons + htu.startMiniZKCluster(); + MiniHBaseCluster hbm = htu.startMiniHBaseCluster(1, 4); + hbm.waitForActiveAndReadyMaster(); + admin = htu.getHBaseAdmin(); } @@ -107,7 +113,8 @@ public static void afterClass() throws Exception { admin = null; } if (htu != null) { - htu.shutdownMiniCluster(); + htu.shutdownMiniHBaseCluster(); + htu.shutdownMiniZKCluster(); htu = null; } } From 595ca1ec84134328be5be3c8ae21a5a43a5a7166 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Mon, 19 Jun 2017 15:48:06 -0700 Subject: [PATCH 057/200] [BEAM-1348] Fix type error introduced into Python SDK because of PR/3268 --- sdks/python/apache_beam/runners/portability/fn_api_runner.py | 2 +- 
sdks/python/apache_beam/runners/worker/sdk_worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index 90764f4dfac01..d792131a9d2eb 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -264,7 +264,7 @@ def outputs(op): element_coder.get_impl().encode_to_stream( element, output_stream, True) elements_data = output_stream.get() - state_key = beam_fn_api_pb2.StateKey(key=view_id) + state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(key=view_id) state_handler.Clear(state_key) state_handler.Append(state_key, elements_data) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index f662538e981dc..d08b1798a94b4 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -358,7 +358,7 @@ def create_side_input(tag, si): tag=tag, source=SideInputSource( self.state_handler, - beam_fn_api_pb2.StateKey( + beam_fn_api_pb2.StateKey.MultimapSideInput( key=si.view_fn.id.encode('utf-8')), coder=unpack_and_deserialize_py_fn(si.view_fn))) output_tags = list(transform.outputs.keys()) From 1ec59a08a3fab5ac0918d7f1a33b82427957b630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Mon, 5 Jun 2017 23:48:38 +0200 Subject: [PATCH 058/200] [BEAM-2411] Make the write transform of HBaseIO simpler --- .../org/apache/beam/sdk/io/hbase/HBaseIO.java | 45 +++++++------------ .../apache/beam/sdk/io/hbase/HBaseIOTest.java | 37 +++++++-------- 2 files changed, 32 insertions(+), 50 deletions(-) diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java index 849873c059706..626fad90b4cb4 100644 --- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java @@ -31,10 +31,7 @@ import java.util.TreeSet; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.hadoop.SerializableConfiguration; import org.apache.beam.sdk.io.range.ByteKey; @@ -44,7 +41,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; @@ -122,16 +118,15 @@ *

  * <h3>Writing to HBase</h3>
  *
  * <p>The HBase sink executes a set of row mutations on a single table. It takes as input a
- * {@link PCollection PCollection<KV<byte[], Iterable<Mutation>>>}, where the
- * {@code byte[]} is the key of the row being mutated, and each {@link Mutation} represents an
- * idempotent transformation to that row.
+ * {@link PCollection PCollection<Mutation>}, where each {@link Mutation} represents an
+ * idempotent transformation on a row.
  *
  * <p>To configure a HBase sink, you must supply a table id and a {@link Configuration}
  * to identify the HBase instance, for example:
  *
  * <pre>{@code
  * Configuration configuration = ...;
- * PCollection<KV<byte[], Iterable<Mutation>>> data = ...;
+ * PCollection<Mutation> data = ...;
  * data.setCoder(HBaseIO.WRITE_CODER);
  *
  * data.apply("write",
    @@ -545,9 +540,7 @@ public static Write write() {
          *
          * @see HBaseIO
          */
-    public static class Write
-            extends PTransform<PCollection<KV<byte[], Iterable<Mutation>>>, PDone> {
-
+    public static class Write extends PTransform<PCollection<Mutation>, PDone> {
             /**
              * Returns a new {@link HBaseIO.Write} that will write to the HBase instance
              * indicated by the given Configuration, and using any other specified customizations.
    @@ -575,7 +568,7 @@ private Write(SerializableConfiguration serializableConfiguration, String tableI
             }
     
             @Override
    -        public PDone expand(PCollection>> input) {
    +        public PDone expand(PCollection input) {
                 input.apply(ParDo.of(new HBaseWriterFn(tableId, serializableConfiguration)));
                 return PDone.in(input.getPipeline());
             }
    @@ -613,7 +606,7 @@ public Configuration getConfiguration() {
             private final String tableId;
             private final SerializableConfiguration serializableConfiguration;
     
    -        private class HBaseWriterFn extends DoFn>, Void> {
    +        private class HBaseWriterFn extends DoFn {
     
                 public HBaseWriterFn(String tableId,
                                      SerializableConfiguration serializableConfiguration) {
    @@ -624,31 +617,27 @@ public HBaseWriterFn(String tableId,
     
                 @Setup
                 public void setup() throws Exception {
    -                Configuration configuration = this.serializableConfiguration.get();
    -                connection = ConnectionFactory.createConnection(configuration);
    +                connection = ConnectionFactory.createConnection(serializableConfiguration.get());
    +            }
     
    -                TableName tableName = TableName.valueOf(tableId);
    +            @StartBundle
    +            public void startBundle(StartBundleContext c) throws IOException {
                     BufferedMutatorParams params =
    -                    new BufferedMutatorParams(tableName);
    +                    new BufferedMutatorParams(TableName.valueOf(tableId));
                     mutator = connection.getBufferedMutator(params);
    -
                     recordsWritten = 0;
                 }
     
                 @ProcessElement
    -            public void processElement(ProcessContext ctx) throws Exception {
    -                KV> record = ctx.element();
    -                List mutations = new ArrayList<>();
    -                for (Mutation mutation : record.getValue()) {
    -                    mutations.add(mutation);
    -                    ++recordsWritten;
    -                }
    -                mutator.mutate(mutations);
    +            public void processElement(ProcessContext c) throws Exception {
    +                mutator.mutate(c.element());
    +                ++recordsWritten;
                 }
     
                 @FinishBundle
                 public void finishBundle() throws Exception {
                     mutator.flush();
    +                LOG.debug("Wrote {} records", recordsWritten);
                 }
     
                 @Teardown
    @@ -661,7 +650,6 @@ public void tearDown() throws Exception {
                         connection.close();
                         connection = null;
                     }
    -                LOG.debug("Wrote {} records", recordsWritten);
                 }
     
                 @Override
    @@ -679,6 +667,5 @@ public void populateDisplayData(DisplayData.Builder builder) {
             }
         }
     
    -    public static final Coder>> WRITE_CODER =
    -            KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(HBaseMutationCoder.of()));
    +    public static final Coder WRITE_CODER = HBaseMutationCoder.of();
     }
    diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    index 005770da86569..d081139b40dd8 100644
    --- a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    @@ -38,7 +38,6 @@
     import org.apache.beam.sdk.transforms.Count;
     import org.apache.beam.sdk.transforms.Create;
     import org.apache.beam.sdk.transforms.display.DisplayData;
    -import org.apache.beam.sdk.values.KV;
     import org.apache.beam.sdk.values.PCollection;
     import org.apache.commons.lang3.StringUtils;
     import org.apache.hadoop.conf.Configuration;
    @@ -292,15 +291,17 @@ public void testWriting() throws Exception {
             final String table = "table";
             final String key = "key";
             final String value = "value";
    +        final int numMutations = 100;
     
             createTable(table);
     
    -        p.apply("single row", Create.of(makeWrite(key, value)).withCoder(HBaseIO.WRITE_CODER))
    -                .apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
    +        p.apply("multiple rows", Create.of(makeMutations(key, value, numMutations))
    +            .withCoder(HBaseIO.WRITE_CODER))
    +         .apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
             p.run().waitUntilFinish();
     
             List results = readTable(table, new Scan());
    -        assertEquals(1, results.size());
    +        assertEquals(numMutations, results.size());
         }
     
         /** Tests that when writing to a non-existent table, the write fails. */
    @@ -308,10 +309,8 @@ public void testWriting() throws Exception {
         public void testWritingFailsTableDoesNotExist() throws Exception {
             final String table = "TEST-TABLE-DOES-NOT-EXIST";
     
    -        PCollection>> emptyInput =
    -                p.apply(Create.empty(HBaseIO.WRITE_CODER));
    -
    -        emptyInput.apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
    +        p.apply(Create.empty(HBaseIO.WRITE_CODER))
    +         .apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
     
             // Exception will be thrown by write.validate() when write is applied.
             thrown.expect(IllegalArgumentException.class);
    @@ -326,7 +325,7 @@ public void testWritingFailsBadElement() throws Exception {
             final String key = "KEY";
             createTable(table);
     
    -        p.apply(Create.of(makeBadWrite(key)).withCoder(HBaseIO.WRITE_CODER))
    +        p.apply(Create.of(makeBadMutation(key)).withCoder(HBaseIO.WRITE_CODER))
                     .apply(HBaseIO.write().withConfiguration(conf).withTableId(table));
     
             thrown.expect(Pipeline.PipelineExecutionException.class);
    @@ -405,26 +404,22 @@ private static List readTable(String tableId, Scan scan) throws Exceptio
     
         // Beam helper methods
         /** Helper function to make a single row mutation to be written. */
    -    private static KV> makeWrite(String key, String value) {
    -        byte[] rowKey = key.getBytes(StandardCharsets.UTF_8);
    +    private static Iterable makeMutations(String key, String value, int numMutations) {
             List mutations = new ArrayList<>();
    -        mutations.add(makeMutation(key, value));
    -        return KV.of(rowKey, (Iterable) mutations);
    +        for (int i = 0; i < numMutations; i++) {
    +            mutations.add(makeMutation(key + i, value));
    +        }
    +        return mutations;
         }
     
    -
         private static Mutation makeMutation(String key, String value) {
    -        byte[] rowKey = key.getBytes(StandardCharsets.UTF_8);
    -        return new Put(rowKey)
    +        return new Put(key.getBytes(StandardCharsets.UTF_8))
                         .addColumn(COLUMN_FAMILY, COLUMN_NAME, Bytes.toBytes(value))
                         .addColumn(COLUMN_FAMILY, COLUMN_EMAIL, Bytes.toBytes(value + "@email.com"));
         }
     
    -    private static KV> makeBadWrite(String key) {
    -        Put put = new Put(key.getBytes());
    -        List mutations = new ArrayList<>();
    -        mutations.add(put);
    -        return KV.of(key.getBytes(StandardCharsets.UTF_8), (Iterable) mutations);
    +    private static Mutation makeBadMutation(String key) {
    +        return new Put(key.getBytes());
         }
     
         private void runReadTest(HBaseIO.Read read, List expected) {
    
    From d42f6333141e85964d009110d8bea85ad4763632 Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= 
    Date: Tue, 20 Jun 2017 10:00:15 +0200
    Subject: [PATCH 059/200] [BEAM-2411] Add HBaseCoderProviderRegistrar for
     better coder inference
    
    ---
     sdks/java/io/hbase/pom.xml                    |  6 +++
     .../io/hbase/HBaseCoderProviderRegistrar.java | 49 +++++++++++++++++++
     .../org/apache/beam/sdk/io/hbase/HBaseIO.java |  3 --
     .../HBaseCoderProviderRegistrarTest.java      | 41 ++++++++++++++++
     .../apache/beam/sdk/io/hbase/HBaseIOTest.java |  9 ++--
     5 files changed, 100 insertions(+), 8 deletions(-)
     create mode 100644 sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java
     create mode 100644 sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java
    
    diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml
    index f81cd2461dcef..4d9d600f246de 100644
    --- a/sdks/java/io/hbase/pom.xml
    +++ b/sdks/java/io/hbase/pom.xml
    @@ -63,6 +63,12 @@
           beam-sdks-java-io-hadoop-common
         
     
    +    
    +      com.google.auto.service
    +      auto-service
    +      true
    +    
    +
         
           org.apache.hbase
           hbase-shaded-client
    diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java
    new file mode 100644
    index 0000000000000..dee3c703addef
    --- /dev/null
    +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java
    @@ -0,0 +1,49 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.beam.sdk.io.hbase;
    +
    +import com.google.auto.service.AutoService;
    +import com.google.common.collect.ImmutableList;
    +import java.util.List;
    +import org.apache.beam.sdk.coders.CoderProvider;
    +import org.apache.beam.sdk.coders.CoderProviderRegistrar;
    +import org.apache.beam.sdk.coders.CoderProviders;
    +import org.apache.beam.sdk.values.TypeDescriptor;
    +import org.apache.hadoop.hbase.client.Append;
    +import org.apache.hadoop.hbase.client.Delete;
    +import org.apache.hadoop.hbase.client.Increment;
    +import org.apache.hadoop.hbase.client.Mutation;
    +import org.apache.hadoop.hbase.client.Put;
    +import org.apache.hadoop.hbase.client.Result;
    +
    +/**
    + * A {@link CoderProviderRegistrar} for standard types used with {@link HBaseIO}.
    + */
    +@AutoService(CoderProviderRegistrar.class)
    +public class HBaseCoderProviderRegistrar implements CoderProviderRegistrar {
    +  @Override
    +  public List getCoderProviders() {
    +    return ImmutableList.of(
    +      CoderProviders.forCoder(TypeDescriptor.of(Append.class), HBaseMutationCoder.of()),
    +      CoderProviders.forCoder(TypeDescriptor.of(Delete.class), HBaseMutationCoder.of()),
    +      CoderProviders.forCoder(TypeDescriptor.of(Increment.class), HBaseMutationCoder.of()),
    +      CoderProviders.forCoder(TypeDescriptor.of(Mutation.class), HBaseMutationCoder.of()),
    +      CoderProviders.forCoder(TypeDescriptor.of(Put.class), HBaseMutationCoder.of()),
    +      CoderProviders.forCoder(TypeDescriptor.of(Result.class), HBaseResultCoder.of()));
    +  }
    +}
    diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
    index 626fad90b4cb4..c9afe8908a5c0 100644
    --- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
    +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
    @@ -127,7 +127,6 @@
      * 
    {@code
      * Configuration configuration = ...;
      * PCollection data = ...;
    - * data.setCoder(HBaseIO.WRITE_CODER);
      *
      * data.apply("write",
      *     HBaseIO.write()
    @@ -666,6 +665,4 @@ public void populateDisplayData(DisplayData.Builder builder) {
                 private long recordsWritten;
             }
         }
    -
    -    public static final Coder WRITE_CODER = HBaseMutationCoder.of();
     }
    diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java
    new file mode 100644
    index 0000000000000..ac81e8a7ad5fa
    --- /dev/null
    +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java
    @@ -0,0 +1,41 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.beam.sdk.io.hbase;
    +
    +import org.apache.beam.sdk.coders.CoderRegistry;
    +import org.apache.hadoop.hbase.client.Mutation;
    +import org.apache.hadoop.hbase.client.Result;
    +import org.junit.Test;
    +import org.junit.runner.RunWith;
    +import org.junit.runners.JUnit4;
    +
    +/**
    + * Tests for {@link HBaseCoderProviderRegistrar}.
    + */
    +@RunWith(JUnit4.class)
    +public class HBaseCoderProviderRegistrarTest {
    +  @Test
    +  public void testResultCoderIsRegistered() throws Exception {
    +    CoderRegistry.createDefault().getCoder(Result.class);
    +  }
    +
    +  @Test
    +  public void testMutationCoderIsRegistered() throws Exception {
    +    CoderRegistry.createDefault().getCoder(Mutation.class);
    +  }
    +}
    diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    index d081139b40dd8..806a27f722841 100644
    --- a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseIOTest.java
    @@ -295,8 +295,7 @@ public void testWriting() throws Exception {
     
             createTable(table);
     
    -        p.apply("multiple rows", Create.of(makeMutations(key, value, numMutations))
    -            .withCoder(HBaseIO.WRITE_CODER))
    +        p.apply("multiple rows", Create.of(makeMutations(key, value, numMutations)))
              .apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
             p.run().waitUntilFinish();
     
    @@ -309,7 +308,7 @@ public void testWriting() throws Exception {
         public void testWritingFailsTableDoesNotExist() throws Exception {
             final String table = "TEST-TABLE-DOES-NOT-EXIST";
     
    -        p.apply(Create.empty(HBaseIO.WRITE_CODER))
    +        p.apply(Create.empty(HBaseMutationCoder.of()))
              .apply("write", HBaseIO.write().withConfiguration(conf).withTableId(table));
     
             // Exception will be thrown by write.validate() when write is applied.
    @@ -325,8 +324,8 @@ public void testWritingFailsBadElement() throws Exception {
             final String key = "KEY";
             createTable(table);
     
    -        p.apply(Create.of(makeBadMutation(key)).withCoder(HBaseIO.WRITE_CODER))
    -                .apply(HBaseIO.write().withConfiguration(conf).withTableId(table));
    +        p.apply(Create.of(makeBadMutation(key)))
    +         .apply(HBaseIO.write().withConfiguration(conf).withTableId(table));
     
             thrown.expect(Pipeline.PipelineExecutionException.class);
             thrown.expectCause(Matchers.instanceOf(IllegalArgumentException.class));
    
    From 6e4357225477d6beb4cb9735255d1759f4fab168 Mon Sep 17 00:00:00 2001
    From: Eugene Kirpichov 
    Date: Mon, 19 Jun 2017 11:56:29 -0700
    Subject: [PATCH 060/200] Retries http code 0 (usually network error)
    
    ---
     .../apache/beam/sdk/util/RetryHttpRequestInitializer.java    | 5 +++--
     1 file changed, 3 insertions(+), 2 deletions(-)
    
    diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java
    index e5b48d39664f0..a23bee387e2bd 100644
    --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java
    +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java
    @@ -96,8 +96,9 @@ public LoggingHttpBackoffUnsuccessfulResponseHandler(BackOff backoff,
                 @Override
                 public boolean isRequired(HttpResponse response) {
                   int statusCode = response.getStatusCode();
    -              return (statusCode / 100 == 5) ||  // 5xx: server error
    -                  statusCode == 429;             // 429: Too many requests
    +              return (statusCode == 0) // Code 0 usually means no response / network error
    +                  || (statusCode / 100 == 5) // 5xx: server error
    +                  || statusCode == 429; // 429: Too many requests
                 }
               });
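Note: a condensed restatement of the retry predicate after this change, as a standalone sketch; the real logic lives in the anonymous handler above.

    // Retry when there was no HTTP response at all, on server errors, or on throttling.
    static boolean shouldRetry(int statusCode) {
      return statusCode == 0          // no response / network error
          || statusCode / 100 == 5    // 5xx: server error
          || statusCode == 429;       // 429: Too many requests
    }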
         }
    
    From 5e12e9d75ab78f210b3b024a77c52aaec033218c Mon Sep 17 00:00:00 2001
    From: jasonkuster 
    Date: Tue, 20 Jun 2017 12:05:22 -0700
    Subject: [PATCH 061/200] Remove notifications from JDK versions test.
    
    ---
     .../jenkins/job_beam_PostCommit_Java_JDKVersionsTest.groovy     | 2 ++
     1 file changed, 2 insertions(+)
    
    diff --git a/.test-infra/jenkins/job_beam_PostCommit_Java_JDKVersionsTest.groovy b/.test-infra/jenkins/job_beam_PostCommit_Java_JDKVersionsTest.groovy
    index f23e741a00ef6..df0a2c7a6d3af 100644
    --- a/.test-infra/jenkins/job_beam_PostCommit_Java_JDKVersionsTest.groovy
    +++ b/.test-infra/jenkins/job_beam_PostCommit_Java_JDKVersionsTest.groovy
    @@ -37,6 +37,8 @@ matrixJob('beam_PostCommit_Java_JDK_Versions_Test') {
       common_job_properties.setPostCommit(
           delegate,
           '0 */6 * * *',
    +      false,
    +      '',  // TODO: Remove last two args once test is stable again.
           false)
     
       // Allows triggering this build against pull requests.
    
    From b7ff103f6ee10b07c50ddbd5a49a6a8ce6686087 Mon Sep 17 00:00:00 2001
    From: Eugene Kirpichov 
    Date: Fri, 16 Jun 2017 14:27:51 -0700
    Subject: [PATCH 062/200] Increases backoff in GcsUtil
    
    ---
     .../src/main/java/org/apache/beam/sdk/util/GcsUtil.java         | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
    index 8d1fe74ad270d..d7205bf756318 100644
    --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
    +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/GcsUtil.java
    @@ -135,7 +135,7 @@ public static GcsUtil create(
       private static final int MAX_CONCURRENT_BATCHES = 256;
     
       private static final FluentBackoff BACKOFF_FACTORY =
    -      FluentBackoff.DEFAULT.withMaxRetries(3).withInitialBackoff(Duration.millis(200));
    +      FluentBackoff.DEFAULT.withMaxRetries(10).withInitialBackoff(Duration.standardSeconds(1));
     
       /////////////////////////////////////////////////////////////////////////////
     
    
    From a0523b2dab617d6aee59708a8d8959f42049fce9 Mon Sep 17 00:00:00 2001
    From: Vikas Kedigehalli 
    Date: Mon, 19 Jun 2017 11:24:14 -0700
    Subject: [PATCH 063/200] Fix dataflow runner test to call pipeline.run instead
     of runner.run
    
    ---
     .../apache_beam/runners/dataflow/dataflow_runner_test.py     | 5 +----
     1 file changed, 1 insertion(+), 4 deletions(-)
    
    diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
    index 6cc5814a5fbcf..a9b8fdb2a24c3 100644
    --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
    +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
    @@ -116,10 +116,7 @@ def test_streaming_create_translation(self):
         self.default_properties.append("--streaming")
         p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
         p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
    -    remote_runner.job = apiclient.Job(p._options)
    -    # Performing configured PTransform overrides here.
    -    p.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
    -    super(DataflowRunner, remote_runner).run(p)
    +    p.run()
         job_dict = json.loads(str(remote_runner.job))
         self.assertEqual(len(job_dict[u'steps']), 2)
     
    
    From 08ec0d4dbff330ecd48c806cd764ab5a96835bd9 Mon Sep 17 00:00:00 2001
    From: Robert Bradshaw 
    Date: Tue, 20 Jun 2017 11:01:03 -0700
    Subject: [PATCH 064/200] Port fn_api_runner to be able to use runner protos.
    
    ---
     .../apache_beam/runners/pipeline_context.py   |  17 +-
     .../runners/portability/fn_api_runner.py      | 166 +++++++++++-
     .../runners/portability/fn_api_runner_test.py |  20 +-
     .../apache_beam/runners/worker/sdk_worker.py  | 243 +++++++++++++++++-
     4 files changed, 420 insertions(+), 26 deletions(-)
    
    diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py
    index e212abf8d9fc2..c2ae3f33650dc 100644
    --- a/sdks/python/apache_beam/runners/pipeline_context.py
    +++ b/sdks/python/apache_beam/runners/pipeline_context.py
    @@ -24,6 +24,7 @@
     from apache_beam import pipeline
     from apache_beam import pvalue
     from apache_beam import coders
    +from apache_beam.portability.api import beam_fn_api_pb2
     from apache_beam.portability.api import beam_runner_api_pb2
     from apache_beam.transforms import core
     
    @@ -42,9 +43,10 @@ def __init__(self, context, obj_type, proto_map=None):
         self._id_to_proto = proto_map if proto_map else {}
         self._counter = 0
     
    -  def _unique_ref(self):
    +  def _unique_ref(self, obj=None):
         self._counter += 1
    -    return "ref_%s_%s" % (self._obj_type.__name__, self._counter)
    +    return "ref_%s_%s_%s" % (
    +        self._obj_type.__name__, type(obj).__name__, self._counter)
     
       def populate_map(self, proto_map):
         for id, proto in self._id_to_proto.items():
    @@ -52,7 +54,7 @@ def populate_map(self, proto_map):
     
       def get_id(self, obj):
         if obj not in self._obj_to_id:
    -      id = self._unique_ref()
    +      id = self._unique_ref(obj)
           self._id_to_obj[id] = obj
           self._obj_to_id[obj] = id
           self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context)
    @@ -79,11 +81,16 @@ class PipelineContext(object):
           # TODO: environment
       }
     
    -  def __init__(self, context_proto=None):
    +  def __init__(self, proto=None):
    +    if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
    +      proto = beam_runner_api_pb2.Components(
    +          coders=dict(proto.codersyyy.items()),
    +          windowing_strategies=dict(proto.windowing_strategies.items()),
    +          environments=dict(proto.environments.items()))
         for name, cls in self._COMPONENT_TYPES.items():
           setattr(
               self, name, _PipelineContextMap(
    -              self, cls, getattr(context_proto, name, None)))
    +              self, cls, getattr(proto, name, None)))
     
       @staticmethod
       def from_runner_api(proto):
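
     The _unique_ref() change in this file folds the wrapped object's type name into each
     generated id (for example ref_Coder_BytesCoder_1 rather than ref_Coder_1), which makes
     the serialized component maps easier to read when debugging. The sketch below is a
     stripped-down, standalone version of the bookkeeping _PipelineContextMap does; it is
     an illustration, not the Beam class itself.

         # Standalone sketch of object<->id bookkeeping: each distinct object gets one
         # stable id, and the id embeds the object's type name for readability.
         class ComponentMap(object):

           def __init__(self, obj_type):
             self._obj_type = obj_type
             self._obj_to_id = {}
             self._id_to_obj = {}
             self._counter = 0

           def _unique_ref(self, obj=None):
             self._counter += 1
             return 'ref_%s_%s_%d' % (
                 self._obj_type.__name__, type(obj).__name__, self._counter)

           def get_id(self, obj):
             # Assign an id the first time an object is seen; reuse it afterwards.
             if obj not in self._obj_to_id:
               new_id = self._unique_ref(obj)
               self._obj_to_id[obj] = new_id
               self._id_to_obj[new_id] = obj
             return self._obj_to_id[obj]

     For example, ComponentMap(str).get_id('windowfn') returns the same id on every call,
     so repeated references to one component never duplicate it in the serialized proto.
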
    diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
    index d792131a9d2eb..dabb7d687db51 100644
    --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
    +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
    @@ -24,9 +24,10 @@
     import threading
     
     from concurrent import futures
    +from google.protobuf import wrappers_pb2
     import grpc
     
    -import apache_beam as beam
    +import apache_beam as beam  # pylint: disable=ungrouped-imports
     from apache_beam.coders import WindowedValueCoder
     from apache_beam.coders.coder_impl import create_InputStream
     from apache_beam.coders.coder_impl import create_OutputStream
    @@ -34,10 +35,13 @@
     from apache_beam.io import iobase
     from apache_beam.transforms.window import GlobalWindows
     from apache_beam.portability.api import beam_fn_api_pb2
    +from apache_beam.portability.api import beam_runner_api_pb2
    +from apache_beam.runners import pipeline_context
     from apache_beam.runners.portability import maptask_executor_runner
     from apache_beam.runners.worker import data_plane
     from apache_beam.runners.worker import operation_specs
     from apache_beam.runners.worker import sdk_worker
    +from apache_beam.utils import proto_utils
     
     # This module is experimental. No backwards-compatibility guarantees.
     
    @@ -110,9 +114,13 @@ def process(self, source):
     
     class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
     
    -  def __init__(self):
    +  def __init__(self, use_runner_protos=False):
         super(FnApiRunner, self).__init__()
         self._last_uid = -1
    +    if use_runner_protos:
    +      self._map_task_to_protos = self._map_task_to_runner_protos
    +    else:
    +      self._map_task_to_protos = self._map_task_to_fn_protos
     
       def has_metrics_support(self):
         return False
    @@ -123,7 +131,140 @@ def _next_uid(self):
     
       def _map_task_registration(self, map_task, state_handler,
                                  data_operation_spec):
    +    input_data, side_input_data, runner_sinks, process_bundle_descriptor = (
    +        self._map_task_to_protos(map_task, data_operation_spec))
    +    # Side inputs will be accessed over the state API.
    +    for key, elements_data in side_input_data.items():
    +      state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(key=key)
    +      state_handler.Clear(state_key)
    +      state_handler.Append(state_key, [elements_data])
    +    return beam_fn_api_pb2.InstructionRequest(
    +        instruction_id=self._next_uid(),
    +        register=beam_fn_api_pb2.RegisterRequest(
    +            process_bundle_descriptor=[process_bundle_descriptor])
    +        ), runner_sinks, input_data
    +
    +  def _map_task_to_runner_protos(self, map_task, data_operation_spec):
    +    input_data = {}
    +    side_input_data = {}
    +    runner_sinks = {}
    +
    +    context = pipeline_context.PipelineContext()
    +    transform_protos = {}
    +    used_pcollections = {}
    +
    +    def uniquify(*names):
    +      # An injective mapping from string* to string.
    +      return ':'.join("%s:%d" % (name, len(name)) for name in names)
    +
    +    def pcollection_id(op_ix, out_ix):
    +      if (op_ix, out_ix) not in used_pcollections:
    +        used_pcollections[op_ix, out_ix] = uniquify(
    +            map_task[op_ix][0], 'out', str(out_ix))
    +      return used_pcollections[op_ix, out_ix]
    +
    +    def get_inputs(op):
    +      if hasattr(op, 'inputs'):
    +        inputs = op.inputs
    +      elif hasattr(op, 'input'):
    +        inputs = [op.input]
    +      else:
    +        inputs = []
    +      return {'in%s' % ix: pcollection_id(*input)
    +              for ix, input in enumerate(inputs)}
    +
    +    def get_outputs(op_ix):
    +      op = map_task[op_ix][1]
    +      return {tag: pcollection_id(op_ix, out_ix)
    +              for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))}
    +
    +    for op_ix, (stage_name, operation) in enumerate(map_task):
    +      transform_id = uniquify(stage_name)
    +
    +      if isinstance(operation, operation_specs.WorkerInMemoryWrite):
    +        # Write this data back to the runner.
    +        runner_sinks[(transform_id, 'out')] = operation
    +        transform_spec = beam_runner_api_pb2.FunctionSpec(
    +            urn=sdk_worker.DATA_OUTPUT_URN,
    +            parameter=proto_utils.pack_Any(data_operation_spec))
    +
    +      elif isinstance(operation, operation_specs.WorkerRead):
    +        # A Read from an in-memory source is done over the data plane.
    +        if (isinstance(operation.source.source,
    +                       maptask_executor_runner.InMemorySource)
    +            and isinstance(operation.source.source.default_output_coder(),
    +                           WindowedValueCoder)):
    +          input_data[(transform_id, 'input')] = self._reencode_elements(
    +              operation.source.source.read(None),
    +              operation.source.source.default_output_coder())
    +          transform_spec = beam_runner_api_pb2.FunctionSpec(
    +              urn=sdk_worker.DATA_INPUT_URN,
    +              parameter=proto_utils.pack_Any(data_operation_spec))
    +
    +        else:
    +          # Otherwise serialize the source and execute it there.
    +          # TODO: Use SDFs with an initial impulse.
    +          transform_spec = beam_runner_api_pb2.FunctionSpec(
    +              urn=sdk_worker.PYTHON_SOURCE_URN,
    +              parameter=proto_utils.pack_Any(
    +                  wrappers_pb2.BytesValue(
    +                      value=pickler.dumps(operation.source.source))))
    +
    +      elif isinstance(operation, operation_specs.WorkerDoFn):
    +        # Record the contents of each side input for access via the state api.
    +        side_input_extras = []
    +        for si in operation.side_inputs:
    +          assert isinstance(si.source, iobase.BoundedSource)
    +          element_coder = si.source.default_output_coder()
    +          # TODO(robertwb): Actually flesh out the ViewFn API.
    +          side_input_extras.append((si.tag, element_coder))
    +          side_input_data[sdk_worker.side_input_tag(transform_id, si.tag)] = (
    +              self._reencode_elements(
    +                  si.source.read(si.source.get_range_tracker(None, None)),
    +                  element_coder))
    +        augmented_serialized_fn = pickler.dumps(
    +            (operation.serialized_fn, side_input_extras))
    +        transform_spec = beam_runner_api_pb2.FunctionSpec(
    +            urn=sdk_worker.PYTHON_DOFN_URN,
    +            parameter=proto_utils.pack_Any(
    +                wrappers_pb2.BytesValue(value=augmented_serialized_fn)))
    +
    +      elif isinstance(operation, operation_specs.WorkerFlatten):
    +        # Flatten is nice and simple.
    +        transform_spec = beam_runner_api_pb2.FunctionSpec(
    +            urn=sdk_worker.IDENTITY_DOFN_URN)
    +
    +      else:
    +        raise NotImplementedError(operation)
    +
    +      transform_protos[transform_id] = beam_runner_api_pb2.PTransform(
    +          unique_name=stage_name,
    +          spec=transform_spec,
    +          inputs=get_inputs(operation),
    +          outputs=get_outputs(op_ix))
    +
    +    pcollection_protos = {
    +        name: beam_runner_api_pb2.PCollection(
    +            unique_name=name,
    +            coder_id=context.coders.get_id(
    +                map_task[op_id][1].output_coders[out_id]))
    +        for (op_id, out_id), name in used_pcollections.items()
    +    }
    +    # Must follow creation of pcollection_protos to capture used coders.
    +    context_proto = context.to_runner_api()
    +    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
    +        id=self._next_uid(),
    +        transforms=transform_protos,
    +        pcollections=pcollection_protos,
    +        codersyyy=dict(context_proto.coders.items()),
    +        windowing_strategies=dict(context_proto.windowing_strategies.items()),
    +        environments=dict(context_proto.environments.items()))
    +    return input_data, side_input_data, runner_sinks, process_bundle_descriptor
    +
    +  def _map_task_to_fn_protos(self, map_task, data_operation_spec):
    +
         input_data = {}
    +    side_input_data = {}
         runner_sinks = {}
         transforms = []
         transform_index_to_id = {}
    @@ -264,9 +405,7 @@ def outputs(op):
                 element_coder.get_impl().encode_to_stream(
                     element, output_stream, True)
               elements_data = output_stream.get()
    -          state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(key=view_id)
    -          state_handler.Clear(state_key)
    -          state_handler.Append(state_key, elements_data)
    +          side_input_data[view_id] = elements_data
     
           elif isinstance(operation, operation_specs.WorkerFlatten):
             fn = sdk_worker.pack_function_spec_data(
    @@ -299,13 +438,11 @@ def outputs(op):
           transforms.append(ptransform)
     
         process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
    -        id=self._next_uid(), coders=coders.values(),
    +        id=self._next_uid(),
    +        coders=coders.values(),
             primitive_transform=transforms)
    -    return beam_fn_api_pb2.InstructionRequest(
    -        instruction_id=self._next_uid(),
    -        register=beam_fn_api_pb2.RegisterRequest(
    -            process_bundle_descriptor=[process_bundle_descriptor
    -                                      ])), runner_sinks, input_data
    +
    +    return input_data, side_input_data, runner_sinks, process_bundle_descriptor
     
       def _run_map_task(
           self, map_task, control_handler, state_handler, data_plane_handler,
    @@ -467,3 +604,10 @@ def close(self):
           self.data_plane_handler.close()
           self.control_server.stop(5).wait()
           self.data_server.stop(5).wait()
    +
    +  @staticmethod
    +  def _reencode_elements(elements, element_coder):
    +    output_stream = create_OutputStream()
    +    for element in elements:
    +      element_coder.get_impl().encode_to_stream(element, output_stream, True)
    +    return output_stream.get()
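
     The uniquify() helper introduced above length-prefixes each name component before
     joining, so two different name tuples can never collapse to the same transform or
     PCollection id (a plain ':'.join would map ('a:b',) and ('a', 'b') to the same
     string). The snippet below is a self-contained check of that property; the helper
     simply mirrors the one in the diff.

         # Check that length-prefixed joining is injective where a naive join is not.
         def uniquify(*names):
           # An injective mapping from string* to string.
           return ':'.join('%s:%d' % (name, len(name)) for name in names)

         def naive_join(*names):
           return ':'.join(names)

         # The naive join collides on two different inputs...
         assert naive_join('a:b') == naive_join('a', 'b')
         # ...while the length-prefixed form keeps them apart.
         assert uniquify('a:b') != uniquify('a', 'b')
         print(uniquify('stage_name', 'out', '0'))  # stage_name:10:out:3:0:1
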
    diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
    index 66d985a9b053c..e2eae26b2179f 100644
    --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
    +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
    @@ -23,10 +23,26 @@
     from apache_beam.runners.portability import maptask_executor_runner_test
     
     
    -class FnApiRunnerTest(maptask_executor_runner_test.MapTaskExecutorRunnerTest):
    +class FnApiRunnerTestWithRunnerProtos(
    +    maptask_executor_runner_test.MapTaskExecutorRunnerTest):
     
       def create_pipeline(self):
    -    return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
    +    return beam.Pipeline(
    +        runner=fn_api_runner.FnApiRunner(use_runner_protos=True))
    +
    +  def test_combine_per_key(self):
    +    # TODO(robertwb): Implement PGBKCV operation.
    +    pass
    +
    +  # Inherits all tests from maptask_executor_runner.MapTaskExecutorRunner
    +
    +
    +class FnApiRunnerTestWithFnProtos(
    +    maptask_executor_runner_test.MapTaskExecutorRunnerTest):
    +
    +  def create_pipeline(self):
    +    return beam.Pipeline(
    +        runner=fn_api_runner.FnApiRunner(use_runner_protos=False))
     
       def test_combine_per_key(self):
         # TODO(robertwb): Implement PGBKCV operation.
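
     Both test classes above inherit the full MapTaskExecutorRunnerTest suite and differ
     only in how create_pipeline() configures the runner, so every inherited test runs once
     per proto mode. Below is a generic, self-contained sketch of that pattern with
     unittest; ToyRunner and the flag name are placeholders, not Beam classes.

         # Run one inherited test suite against two runner configurations by
         # overriding a single class attribute. ToyRunner is a placeholder.
         import unittest

         class ToyRunner(object):
           def __init__(self, use_runner_protos=False):
             self.use_runner_protos = use_runner_protos

         class RunnerTestBase(unittest.TestCase):
           use_runner_protos = False  # overridden by subclasses

           def create_runner(self):
             return ToyRunner(use_runner_protos=self.use_runner_protos)

           def test_flag_is_applied(self):
             # Inherited by every subclass, so the same assertion runs once per
             # configuration.
             self.assertEqual(
                 self.create_runner().use_runner_protos, self.use_runner_protos)

         class RunnerProtosTest(RunnerTestBase):
           use_runner_protos = True

         class FnProtosTest(RunnerTestBase):
           use_runner_protos = False

         if __name__ == '__main__':
           unittest.main()
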
    diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
    index d08b1798a94b4..fd7ecc4325a89 100644
    --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
    +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
    @@ -36,11 +36,13 @@
     from apache_beam.coders import WindowedValueCoder
     from apache_beam.internal import pickler
     from apache_beam.io import iobase
    -from apache_beam.runners.dataflow.native_io import iobase as native_iobase
    -from apache_beam.utils import counters
     from apache_beam.portability.api import beam_fn_api_pb2
    +from apache_beam.runners.dataflow.native_io import iobase as native_iobase
    +from apache_beam.runners import pipeline_context
     from apache_beam.runners.worker import operation_specs
     from apache_beam.runners.worker import operations
    +from apache_beam.utils import counters
    +from apache_beam.utils import proto_utils
     
     # This module is experimental. No backwards-compatibility guarantees.
     
    @@ -62,6 +64,10 @@
     PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
     
     
    +def side_input_tag(transform_id, tag):
    +  return str("%d[%s][%s]" % (len(transform_id), transform_id, tag))
    +
    +
     class RunnerIOOperation(operations.Operation):
       """Common baseclass for runner harness IO operations."""
     
    @@ -208,6 +214,23 @@ def load_compressed(compressed_data):
         dill.dill._trace(False)  # pylint: disable=protected-access
     
     
    +def memoize(func):
    +  cache = {}
    +  missing = object()
    +
    +  def wrapper(*args):
    +    result = cache.get(args, missing)
    +    if result is missing:
    +      result = cache[args] = func(*args)
    +    return result
    +  return wrapper
    +
    +
    +def only_element(iterable):
    +  element, = iterable
    +  return element
    +
    +
     class SdkHarness(object):
     
       def __init__(self, control_channel):
    @@ -296,6 +319,51 @@ def initial_source_split(self, request, unused_instruction_id=None):
         return response
     
       def create_execution_tree(self, descriptor):
    +    if descriptor.primitive_transform:
    +      return self.create_execution_tree_from_fn_api(descriptor)
    +    else:
    +      return self.create_execution_tree_from_runner_api(descriptor)
    +
    +  def create_execution_tree_from_runner_api(self, descriptor):
    +    # TODO(robertwb): Figure out the correct prefix to use for output counters
    +    # from StateSampler.
    +    counter_factory = counters.CounterFactory()
    +    state_sampler = statesampler.StateSampler(
    +        'fnapi-step%s-' % descriptor.id, counter_factory)
    +
    +    transform_factory = BeamTransformFactory(
    +        descriptor, self.data_channel_factory, counter_factory, state_sampler,
    +        self.state_handler)
    +
    +    pcoll_consumers = collections.defaultdict(list)
    +    for transform_id, transform_proto in descriptor.transforms.items():
    +      for pcoll_id in transform_proto.inputs.values():
    +        pcoll_consumers[pcoll_id].append(transform_id)
    +
    +    @memoize
    +    def get_operation(transform_id):
    +      transform_consumers = {
    +          tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
    +          for tag, pcoll_id
    +          in descriptor.transforms[transform_id].outputs.items()
    +      }
    +      return transform_factory.create_operation(
    +          transform_id, transform_consumers)
    +
    +    # Operations must be started (hence returned) in order.
    +    @memoize
    +    def topological_height(transform_id):
    +      return 1 + max(
    +          [0] +
    +          [topological_height(consumer)
    +           for pcoll in descriptor.transforms[transform_id].outputs.values()
    +           for consumer in pcoll_consumers[pcoll]])
    +
    +    return [get_operation(transform_id)
    +            for transform_id in sorted(
    +                descriptor.transforms, key=topological_height, reverse=True)]
    +
    +  def create_execution_tree_from_fn_api(self, descriptor):
         # TODO(vikasrk): Add an id field to Coder proto and use that instead.
         coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
             json.loads(unpack_function_spec_data(coder.function_spec)))
    @@ -418,14 +486,14 @@ def create_side_input(tag, si):
           reversed_ops.append(op)
           ops_by_id[transform.id] = op
     
    -    return list(reversed(reversed_ops)), ops_by_id
    +    return list(reversed(reversed_ops))
     
       def process_bundle(self, request, instruction_id):
    -    ops, ops_by_id = self.create_execution_tree(
    +    ops = self.create_execution_tree(
             self.fns[request.process_bundle_descriptor_reference])
     
         expected_inputs = []
    -    for _, op in ops_by_id.items():
    +    for op in ops:
           if isinstance(op, DataOutputOperation):
             # TODO(robertwb): Is there a better way to pass the instruction id to
             # the operation?
    @@ -445,9 +513,7 @@ def process_bundle(self, request, instruction_id):
           for data in input_op.data_channel.input_elements(
               instruction_id, [input_op.target]):
             # ignores input name
    -        target_op = ops_by_id[data.target.primitive_transform_reference]
    -        # lacks coder for non-input ops
    -        target_op.process_encoded(data.data)
    +        input_op.process_encoded(data.data)
     
         # Finish all operations.
         for op in ops:
    @@ -455,3 +521,164 @@ def process_bundle(self, request, instruction_id):
           op.finish()
     
         return beam_fn_api_pb2.ProcessBundleResponse()
    +
    +
    +class BeamTransformFactory(object):
    +  """Factory for turning transform_protos into executable operations."""
    +  def __init__(self, descriptor, data_channel_factory, counter_factory,
    +               state_sampler, state_handler):
    +    self.descriptor = descriptor
    +    self.data_channel_factory = data_channel_factory
    +    self.counter_factory = counter_factory
    +    self.state_sampler = state_sampler
    +    self.state_handler = state_handler
    +    self.context = pipeline_context.PipelineContext(descriptor)
    +
    +  _known_urns = {}
    +
    +  @classmethod
    +  def register_urn(cls, urn, parameter_type):
    +    def wrapper(func):
    +      cls._known_urns[urn] = func, parameter_type
    +      return func
    +    return wrapper
    +
    +  def create_operation(self, transform_id, consumers):
    +    transform_proto = self.descriptor.transforms[transform_id]
    +    creator, parameter_type = self._known_urns[transform_proto.spec.urn]
    +    parameter = proto_utils.unpack_Any(
    +        transform_proto.spec.parameter, parameter_type)
    +    return creator(self, transform_id, transform_proto, parameter, consumers)
    +
    +  def get_coder(self, coder_id):
    +    return self.context.coders.get_by_id(coder_id)
    +
    +  def get_output_coders(self, transform_proto):
    +    return {
    +        tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
    +        for tag, pcoll_id in transform_proto.outputs.items()
    +    }
    +
    +  def get_only_output_coder(self, transform_proto):
    +    return only_element(self.get_output_coders(transform_proto).values())
    +
    +  def get_input_coders(self, transform_proto):
    +    return {
    +        tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
    +        for tag, pcoll_id in transform_proto.inputs.items()
    +    }
    +
    +  def get_only_input_coder(self, transform_proto):
    +    return only_element(self.get_input_coders(transform_proto).values())
    +
    +  # TODO(robertwb): Update all operations to take these in the constructor.
    +  @staticmethod
    +  def augment_oldstyle_op(op, step_name, consumers, tag_list=None):
    +    op.step_name = step_name
    +    for tag, op_consumers in consumers.items():
    +      for consumer in op_consumers:
    +        op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
    +    return op
    +
    +
    +@BeamTransformFactory.register_urn(
    +    DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
    +def create(factory, transform_id, transform_proto, grpc_port, consumers):
    +  target = beam_fn_api_pb2.Target(
    +      primitive_transform_reference=transform_id,
    +      name=only_element(transform_proto.outputs.keys()))
    +  return DataInputOperation(
    +      transform_proto.unique_name,
    +      transform_proto.unique_name,
    +      consumers,
    +      factory.counter_factory,
    +      factory.state_sampler,
    +      factory.get_only_output_coder(transform_proto),
    +      input_target=target,
    +      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
    +
    +
    +@BeamTransformFactory.register_urn(
    +    DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
    +def create(factory, transform_id, transform_proto, grpc_port, consumers):
    +  target = beam_fn_api_pb2.Target(
    +      primitive_transform_reference=transform_id,
    +      name='out')
    +  return DataOutputOperation(
    +      transform_proto.unique_name,
    +      transform_proto.unique_name,
    +      consumers,
    +      factory.counter_factory,
    +      factory.state_sampler,
    +      # TODO(robertwb): Perhaps this could be distinct from the input coder?
    +      factory.get_only_input_coder(transform_proto),
    +      target=target,
    +      data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
    +
    +
    +@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue)
    +def create(factory, transform_id, transform_proto, parameter, consumers):
    +  source = pickler.loads(parameter.value)
    +  spec = operation_specs.WorkerRead(
    +      iobase.SourceBundle(1.0, source, None, None),
    +      [WindowedValueCoder(source.default_output_coder())])
    +  return factory.augment_oldstyle_op(
    +      operations.ReadOperation(
    +          transform_proto.unique_name,
    +          spec,
    +          factory.counter_factory,
    +          factory.state_sampler),
    +      transform_proto.unique_name,
    +      consumers)
    +
    +
    +@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue)
    +def create(factory, transform_id, transform_proto, parameter, consumers):
    +  dofn_data = pickler.loads(parameter.value)
    +  if len(dofn_data) == 2:
    +    # Has side input data.
    +    serialized_fn, side_input_data = dofn_data
    +  else:
    +    # No side input data.
    +    serialized_fn, side_input_data = parameter.value, []
    +
    +  def create_side_input(tag, coder):
    +    # TODO(robertwb): Extract windows (and keys) out of element data.
    +    # TODO(robertwb): Extract state key from ParDoPayload.
    +    return operation_specs.WorkerSideInputSource(
    +        tag=tag,
    +        source=SideInputSource(
    +            factory.state_handler,
    +            beam_fn_api_pb2.StateKey.MultimapSideInput(
    +                key=side_input_tag(transform_id, tag)),
    +            coder=coder))
    +  output_tags = list(transform_proto.outputs.keys())
    +  output_coders = factory.get_output_coders(transform_proto)
    +  spec = operation_specs.WorkerDoFn(
    +      serialized_fn=serialized_fn,
    +      output_tags=output_tags,
    +      input=None,
    +      side_inputs=[
    +          create_side_input(tag, coder) for tag, coder in side_input_data],
    +      output_coders=[output_coders[tag] for tag in output_tags])
    +  return factory.augment_oldstyle_op(
    +      operations.DoOperation(
    +          transform_proto.unique_name,
    +          spec,
    +          factory.counter_factory,
    +          factory.state_sampler),
    +      transform_proto.unique_name,
    +      consumers,
    +      output_tags)
    +
    +
    +@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
    +def create(factory, transform_id, transform_proto, unused_parameter, consumers):
    +  return factory.augment_oldstyle_op(
    +      operations.FlattenOperation(
    +          transform_proto.unique_name,
    +          None,
    +          factory.counter_factory,
    +          factory.state_sampler),
    +      transform_proto.unique_name,
    +      consumers)
    
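     create_execution_tree_from_runner_api() above builds one operation per transform and
     then starts them producers-first by sorting on a memoized "topological height" over
     the PCollection consumer graph. The sketch below replays that ordering on a toy graph;
     the outputs and pcoll_consumers dictionaries are made-up stand-ins for the descriptor
     protos, and memoize() mirrors the helper in the diff.

         # Toy version of the producers-first ordering: transforms that feed others
         # get a larger height and are started first.
         import collections

         def memoize(func):
           cache = {}
           missing = object()

           def wrapper(*args):
             result = cache.get(args, missing)
             if result is missing:
               result = cache[args] = func(*args)
             return result
           return wrapper

         # transform -> list of output pcollections (made-up graph)
         outputs = {'read': ['pc1'], 'pardo': ['pc2'], 'write': []}
         # pcollection -> transforms that consume it
         pcoll_consumers = collections.defaultdict(
             list, {'pc1': ['pardo'], 'pc2': ['write']})

         @memoize
         def topological_height(transform_id):
           return 1 + max(
               [0] + [topological_height(consumer)
                      for pcoll in outputs[transform_id]
                      for consumer in pcoll_consumers[pcoll]])

         print(sorted(outputs, key=topological_height, reverse=True))
         # ['read', 'pardo', 'write'] -- producers before consumers
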
    From f69e3b53fafa4b79b21095d4b65edbe7cfeb7d2a Mon Sep 17 00:00:00 2001
    From: Pei He 
    Date: Mon, 19 Jun 2017 15:55:48 -0700
    Subject: [PATCH 065/200] FlinkRunner: remove the unused
     ReflectiveOneToOneOverrideFactory.
    
    ---
     .../FlinkStreamingPipelineTranslator.java     | 31 -------------------
     1 file changed, 31 deletions(-)
    
    diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    index a88ff071fcaca..d768b01146e1f 100644
    --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    @@ -24,7 +24,6 @@
     import org.apache.beam.runners.core.construction.PTransformMatchers;
     import org.apache.beam.runners.core.construction.PTransformReplacements;
     import org.apache.beam.runners.core.construction.ReplacementOutputs;
    -import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
     import org.apache.beam.runners.core.construction.SplittableParDo;
     import org.apache.beam.runners.core.construction.UnconsumedReads;
     import org.apache.beam.sdk.Pipeline;
    @@ -36,7 +35,6 @@
     import org.apache.beam.sdk.transforms.PTransform;
     import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
     import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
    -import org.apache.beam.sdk.util.InstanceBuilder;
     import org.apache.beam.sdk.values.PCollection;
     import org.apache.beam.sdk.values.PCollectionTuple;
     import org.apache.beam.sdk.values.PValue;
    @@ -198,35 +196,6 @@ boolean canTranslate(T transform, FlinkStreamingTranslationContext context) {
         }
       }
     
    -  private static class ReflectiveOneToOneOverrideFactory<
    -          InputT, OutputT, TransformT extends PTransform, PCollection>>
    -      extends SingleInputOutputOverrideFactory<
    -          PCollection, PCollection, TransformT> {
    -    private final Class, PCollection>> replacement;
    -    private final FlinkRunner runner;
    -
    -    private ReflectiveOneToOneOverrideFactory(
    -        Class, PCollection>> replacement,
    -        FlinkRunner runner) {
    -      this.replacement = replacement;
    -      this.runner = runner;
    -    }
    -
    -    @Override
    -    public PTransformReplacement, PCollection> getReplacementTransform(
    -        AppliedPTransform, PCollection, TransformT> transform) {
    -      return PTransformReplacement.of(
    -          PTransformReplacements.getSingletonMainInput(transform),
    -          InstanceBuilder.ofType(replacement)
    -              .withArg(FlinkRunner.class, runner)
    -              .withArg(
    -                  (Class, PCollection>>)
    -                      transform.getTransform().getClass(),
    -                  transform.getTransform())
    -              .build());
    -    }
    -  }
    -
       /**
        * A {@link PTransformOverrideFactory} that overrides a Splittable DoFn with {@link SplittableParDo}.
    
    From 52794096aa8b4d614423fd787835f5b89b1ea1ac Mon Sep 17 00:00:00 2001
    From: Pei He 
    Date: Mon, 19 Jun 2017 16:10:02 -0700
    Subject: [PATCH 066/200] Flink runner: refactor the translator into two
     phases: rewriting and translating.
    
    ---
     .../FlinkPipelineExecutionEnvironment.java    |  2 +
     .../FlinkStreamingPipelineTranslator.java     | 23 --------
     .../flink/FlinkTransformOverrides.java        | 53 +++++++++++++++++++
     3 files changed, 55 insertions(+), 23 deletions(-)
     create mode 100644 runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
    
    diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
    index fe5dd87e92a92..d2a2016c98a00 100644
    --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
    +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
    @@ -84,6 +84,8 @@ public void translate(FlinkRunner flinkRunner, Pipeline pipeline) {
         this.flinkBatchEnv = null;
         this.flinkStreamEnv = null;
     
    +    pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides(options.isStreaming()));
    +
         PipelineTranslationOptimizer optimizer =
             new PipelineTranslationOptimizer(TranslationMode.BATCH, options);
     
    diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    index d768b01146e1f..27bb4ecfb9fe4 100644
    --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
    @@ -17,11 +17,7 @@
      */
     package org.apache.beam.runners.flink;
     
    -import com.google.common.collect.ImmutableList;
    -import java.util.List;
     import java.util.Map;
    -import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems;
    -import org.apache.beam.runners.core.construction.PTransformMatchers;
     import org.apache.beam.runners.core.construction.PTransformReplacements;
     import org.apache.beam.runners.core.construction.ReplacementOutputs;
     import org.apache.beam.runners.core.construction.SplittableParDo;
    @@ -29,12 +25,10 @@
     import org.apache.beam.sdk.Pipeline;
     import org.apache.beam.sdk.options.PipelineOptions;
     import org.apache.beam.sdk.runners.AppliedPTransform;
    -import org.apache.beam.sdk.runners.PTransformOverride;
     import org.apache.beam.sdk.runners.PTransformOverrideFactory;
     import org.apache.beam.sdk.runners.TransformHierarchy;
     import org.apache.beam.sdk.transforms.PTransform;
     import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
    -import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
     import org.apache.beam.sdk.values.PCollection;
     import org.apache.beam.sdk.values.PCollectionTuple;
     import org.apache.beam.sdk.values.PValue;
    @@ -70,25 +64,8 @@ public FlinkStreamingPipelineTranslator(
     
       @Override
       public void translate(Pipeline pipeline) {
    -    List transformOverrides =
    -        ImmutableList.builder()
    -            .add(
    -                PTransformOverride.of(
    -                    PTransformMatchers.splittableParDoMulti(),
    -                    new SplittableParDoOverrideFactory()))
    -            .add(
    -                PTransformOverride.of(
    -                    PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class),
    -                    new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
    -            .add(
    -                PTransformOverride.of(
    -                    PTransformMatchers.classEqualTo(CreatePCollectionView.class),
    -                    new CreateStreamingFlinkView.Factory()))
    -            .build();
    -
         // Ensure all outputs of all reads are consumed.
         UnconsumedReads.ensureAllReadsConsumed(pipeline);
    -    pipeline.replaceAll(transformOverrides);
         super.translate(pipeline);
       }
     
    diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
    new file mode 100644
    index 0000000000000..1dc8de9101381
    --- /dev/null
    +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkTransformOverrides.java
    @@ -0,0 +1,53 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.beam.runners.flink;
    +
    +import com.google.common.collect.ImmutableList;
    +import java.util.List;
    +import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems;
    +import org.apache.beam.runners.core.construction.PTransformMatchers;
    +import org.apache.beam.runners.core.construction.SplittableParDo;
    +import org.apache.beam.sdk.runners.PTransformOverride;
    +import org.apache.beam.sdk.transforms.PTransform;
    +import org.apache.beam.sdk.transforms.View;
    +
    +/**
    + * {@link PTransform} overrides for Flink runner.
    + */
    +public class FlinkTransformOverrides {
    +  public static List getDefaultOverrides(boolean streaming) {
    +    if (streaming) {
    +      return ImmutableList.builder()
    +          .add(
    +              PTransformOverride.of(
    +                  PTransformMatchers.splittableParDoMulti(),
    +                  new FlinkStreamingPipelineTranslator.SplittableParDoOverrideFactory()))
    +          .add(
    +              PTransformOverride.of(
    +                  PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class),
    +                  new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
    +          .add(
    +              PTransformOverride.of(
    +                  PTransformMatchers.classEqualTo(View.CreatePCollectionView.class),
    +                  new CreateStreamingFlinkView.Factory()))
    +          .build();
    +    } else {
    +      return ImmutableList.of();
    +    }
    +  }
    +}
    
    From 42a2de91adf1387bb8eaf9aa515a24f6f276bf40 Mon Sep 17 00:00:00 2001
    From: Mairbek Khadikov 
    Date: Wed, 14 Jun 2017 13:03:36 -0700
    Subject: [PATCH 067/200] Support ValueProviders in SpannerIO.Write
    
    ---
     .../beam/sdk/io/gcp/spanner/SpannerIO.java    | 31 +++++++++++++------
     .../sdk/io/gcp/spanner/SpannerIOTest.java     | 21 +++++++++++++
     2 files changed, 43 insertions(+), 9 deletions(-)
    
    diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
    index af5253ba1f3b9..8bfc247adda8c 100644
    --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
    +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
    @@ -37,6 +37,7 @@
     
     import org.apache.beam.sdk.annotations.Experimental;
     import org.apache.beam.sdk.options.PipelineOptions;
    +import org.apache.beam.sdk.options.ValueProvider;
     import org.apache.beam.sdk.transforms.DoFn;
     import org.apache.beam.sdk.transforms.PTransform;
     import org.apache.beam.sdk.transforms.ParDo;
    @@ -123,13 +124,13 @@ public static Write write() {
       public abstract static class Write extends PTransform, PDone> {
     
         @Nullable
    -    abstract String getProjectId();
    +    abstract ValueProvider getProjectId();
     
         @Nullable
    -    abstract String getInstanceId();
    +    abstract ValueProvider getInstanceId();
     
         @Nullable
    -    abstract String getDatabaseId();
    +    abstract ValueProvider getDatabaseId();
     
         abstract long getBatchSizeBytes();
     
    @@ -142,11 +143,11 @@ public abstract static class Write extends PTransform, PDo
         @AutoValue.Builder
         abstract static class Builder {
     
    -      abstract Builder setProjectId(String projectId);
    +      abstract Builder setProjectId(ValueProvider projectId);
     
    -      abstract Builder setInstanceId(String instanceId);
    +      abstract Builder setInstanceId(ValueProvider instanceId);
     
    -      abstract Builder setDatabaseId(String databaseId);
    +      abstract Builder setDatabaseId(ValueProvider databaseId);
     
           abstract Builder setBatchSizeBytes(long batchSizeBytes);
     
    @@ -162,6 +163,10 @@ abstract static class Builder {
           *
           * <p>Does not modify this object.
           */
          public Write withProjectId(String projectId) {
     +      return withProjectId(ValueProvider.StaticValueProvider.of(projectId));
     +    }
     +
     +    public Write withProjectId(ValueProvider<String> projectId) {
            return toBuilder().setProjectId(projectId).build();
          }
      
     @@ -172,6 +177,10 @@ public Write withProjectId(String projectId) {
           *
           * <p>Does not modify this object.
           */
          public Write withInstanceId(String instanceId) {
     +      return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId));
     +    }
     +
     +    public Write withInstanceId(ValueProvider<String> instanceId) {
            return toBuilder().setInstanceId(instanceId).build();
          }
      
     @@ -191,6 +200,10 @@ public Write withBatchSizeBytes(long batchSizeBytes) {
           *
           * <p>Does not modify this object.
           */
          public Write withDatabaseId(String databaseId) {
     +      return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId));
     +    }
     +
     +    public Write withDatabaseId(ValueProvider<String> databaseId) {
            return toBuilder().setDatabaseId(databaseId).build();
          }
      
     @@ -291,7 +304,7 @@ public void setup() throws Exception {
            SpannerOptions spannerOptions = getSpannerOptions();
            spanner = spannerOptions.getService();
            dbClient = spanner.getDatabaseClient(
     -          DatabaseId.of(projectId(), spec.getInstanceId(), spec.getDatabaseId()));
     +          DatabaseId.of(projectId(), spec.getInstanceId().get(), spec.getDatabaseId().get()));
            mutations = new ArrayList<>();
            batchSizeBytes = 0;
          }
      
     @@ -309,7 +322,7 @@ public void processElement(ProcessContext c) throws Exception {
          private String projectId() {
            return spec.getProjectId() == null
                ? ServiceOptions.getDefaultProjectId()
     -          : spec.getProjectId();
     +          : spec.getProjectId().get();
          }
      
          @FinishBundle
     @@ -334,7 +347,7 @@ private SpannerOptions getSpannerOptions() {
            spannerOptionsBuider.setServiceFactory(spec.getServiceFactory());
          }
          if (spec.getProjectId() != null) {
     -        spannerOptionsBuider.setProjectId(spec.getProjectId());
     +        spannerOptionsBuider.setProjectId(spec.getProjectId().get());
          }
          return spannerOptionsBuider.build();
        }
     diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
     index 4a759fb119173..1e19a59c4849f 100644
     --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
     +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
     @@ -17,6 +17,9 @@
       */
      package org.apache.beam.sdk.io.gcp.spanner;
      
     +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
     +import static org.hamcrest.Matchers.hasSize;
     +import static org.junit.Assert.assertThat;
      import static org.mockito.Mockito.argThat;
      import static org.mockito.Mockito.mock;
      import static org.mockito.Mockito.times;
     @@ -42,6 +45,7 @@
      import org.apache.beam.sdk.testing.TestPipeline;
      import org.apache.beam.sdk.transforms.Create;
      import org.apache.beam.sdk.transforms.DoFnTester;
     +import org.apache.beam.sdk.transforms.display.DisplayData;
      import org.apache.beam.sdk.values.PCollection;
      import org.junit.Before;
      import org.junit.Rule;
     @@ -231,6 +235,23 @@ public void groups() throws Exception {
              .writeAtLeastOnce(argThat(new IterableOfSize(3)));
        }
      
     +  @Test
     +  public void displayData() throws Exception {
     +    SpannerIO.Write write =
     +        SpannerIO.write()
     +            .withProjectId("test-project")
     +            .withInstanceId("test-instance")
     +            .withDatabaseId("test-database")
     +            .withBatchSizeBytes(123);
     +
     +    DisplayData data = DisplayData.from(write);
     +    assertThat(data.items(), hasSize(4));
     +    assertThat(data, hasDisplayItem("projectId", "test-project"));
     +    assertThat(data, hasDisplayItem("instanceId", "test-instance"));
     +    assertThat(data, hasDisplayItem("databaseId", "test-database"));
     +    assertThat(data, hasDisplayItem("batchSizeBytes", 123));
     +  }
     +
        private static class FakeServiceFactory
            implements ServiceFactory<Spanner, SpannerOptions>, Serializable {
          // Marked as static so they could be returned by serviceFactory, which is serializable.
     
     From 69b01a6118702277348d2f625af669225c9ed99e Mon Sep 17 00:00:00 2001
     From: Reuven Lax 
     Date: Sat, 13 May 2017 12:53:08 -0700
     Subject: [PATCH 068/200] Add spilling code to WriteFiles.
--- runners/direct-java/pom.xml | 3 +- .../beam/runners/direct/DirectRunner.java | 28 ++-- .../org/apache/beam/sdk/io/WriteFiles.java | 133 ++++++++++++++---- .../beam/sdk/testing/TestPipelineOptions.java | 10 ++ .../org/apache/beam/sdk/io/SimpleSink.java | 4 + .../apache/beam/sdk/io/WriteFilesTest.java | 89 +++++++++--- 6 files changed, 209 insertions(+), 58 deletions(-) diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml index bec21139d9895..63465757ff14a 100644 --- a/runners/direct-java/pom.xml +++ b/runners/direct-java/pom.xml @@ -155,7 +155,8 @@ [ - "--runner=DirectRunner" + "--runner=DirectRunner", + "--unitTest" ] diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 136ccf3bd86a5..a16e24dc262bf 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -43,6 +43,7 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; @@ -221,15 +222,18 @@ public DirectPipelineResult run(Pipeline pipeline) { @SuppressWarnings("rawtypes") @VisibleForTesting List defaultTransformOverrides() { - return ImmutableList.builder() - .add( - PTransformOverride.of( - PTransformMatchers.writeWithRunnerDeterminedSharding(), - new WriteWithShardingFactory())) /* Uses a view internally. */ - .add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), - new ViewOverrideFactory())) /* Uses pardos and GBKs */ + TestPipelineOptions testOptions = options.as(TestPipelineOptions.class); + ImmutableList.Builder builder = ImmutableList.builder(); + if (!testOptions.isUnitTest()) { + builder.add( + PTransformOverride.of( + PTransformMatchers.writeWithRunnerDeterminedSharding(), + new WriteWithShardingFactory())); /* Uses a view internally. */ + } + builder = builder.add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), + new ViewOverrideFactory())) /* Uses pardos and GBKs */ .add( PTransformOverride.of( PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), @@ -254,9 +258,9 @@ List defaultTransformOverrides() { new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ .add( PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), - new DirectGroupByKeyOverrideFactory())) /* returns two chained primitives. */ - .build(); + PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), + new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. 
*/ + return builder.build(); } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java index 2fd10ace994af..a220eabfe42d9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.FileBasedSink.FileResult; import org.apache.beam.sdk.io.FileBasedSink.FileResultCoder; @@ -42,6 +43,7 @@ import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -56,8 +58,12 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,6 +92,18 @@ public class WriteFiles extends PTransform, PDone> { private static final Logger LOG = LoggerFactory.getLogger(WriteFiles.class); + // The maximum number of file writers to keep open in a single bundle at a time, since file + // writers default to 64mb buffers. This comes into play when writing per-window files. + // The first 20 files from a single WriteFiles transform will write files inline in the + // transform. Anything beyond that might be shuffled. + // Keep in mind that specific runners may decide to run multiple bundles in parallel, based on + // their own policy. + private static final int DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE = 20; + + // When we spill records, shard the output keys to prevent hotspots. + // We could consider making this a parameter. 
+ private static final int SPILLED_RECORD_SHARDING_FACTOR = 10; + static final int UNKNOWN_SHARDNUM = -1; private FileBasedSink sink; private WriteOperation writeOperation; @@ -98,6 +116,7 @@ public class WriteFiles extends PTransform, PDone> { @Nullable private final ValueProvider numShardsProvider; private final boolean windowedWrites; + private int maxNumWritersPerBundle; /** * Creates a {@link WriteFiles} transform that writes to the given {@link FileBasedSink}, letting @@ -105,18 +124,21 @@ public class WriteFiles extends PTransform, PDone> { */ public static WriteFiles to(FileBasedSink sink) { checkNotNull(sink, "sink"); - return new WriteFiles<>(sink, null /* runner-determined sharding */, null, false); + return new WriteFiles<>(sink, null /* runner-determined sharding */, null, + false, DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE); } private WriteFiles( FileBasedSink sink, @Nullable PTransform, PCollectionView> computeNumShards, @Nullable ValueProvider numShardsProvider, - boolean windowedWrites) { + boolean windowedWrites, + int maxNumWritersPerBundle) { this.sink = sink; this.computeNumShards = computeNumShards; this.numShardsProvider = numShardsProvider; this.windowedWrites = windowedWrites; + this.maxNumWritersPerBundle = maxNumWritersPerBundle; } @Override @@ -213,7 +235,16 @@ public WriteFiles withNumShards(int numShards) { * more information. */ public WriteFiles withNumShards(ValueProvider numShardsProvider) { - return new WriteFiles<>(sink, null, numShardsProvider, windowedWrites); + return new WriteFiles<>(sink, null, numShardsProvider, windowedWrites, + maxNumWritersPerBundle); + } + + /** + * Set the maximum number of writers created in a bundle before spilling to shuffle. + */ + public WriteFiles withMaxNumWritersPerBundle(int maxNumWritersPerBundle) { + return new WriteFiles<>(sink, null, numShardsProvider, windowedWrites, + maxNumWritersPerBundle); } /** @@ -226,7 +257,7 @@ public WriteFiles withNumShards(ValueProvider numShardsProvider) { public WriteFiles withSharding(PTransform, PCollectionView> sharding) { checkNotNull( sharding, "Cannot provide null sharding. Use withRunnerDeterminedSharding() instead"); - return new WriteFiles<>(sink, sharding, null, windowedWrites); + return new WriteFiles<>(sink, sharding, null, windowedWrites, maxNumWritersPerBundle); } /** @@ -234,7 +265,7 @@ public WriteFiles withSharding(PTransform, PCollectionView withRunnerDeterminedSharding() { - return new WriteFiles<>(sink, null, null, windowedWrites); + return new WriteFiles<>(sink, null, null, windowedWrites, maxNumWritersPerBundle); } /** @@ -252,7 +283,8 @@ public WriteFiles withRunnerDeterminedSharding() { * positive value. */ public WriteFiles withWindowedWrites() { - return new WriteFiles<>(sink, computeNumShards, numShardsProvider, true); + return new WriteFiles<>(sink, computeNumShards, numShardsProvider, true, + maxNumWritersPerBundle); } /** @@ -260,7 +292,13 @@ public WriteFiles withWindowedWrites() { * {@link WriteOperation} associated with the {@link FileBasedSink} with windowed writes enabled. 
*/ private class WriteWindowedBundles extends DoFn { + private final TupleTag> unwrittedRecordsTag; private Map, Writer> windowedWriters; + int spilledShardNum = UNKNOWN_SHARDNUM; + + WriteWindowedBundles(TupleTag> unwrittedRecordsTag) { + this.unwrittedRecordsTag = unwrittedRecordsTag; + } @StartBundle public void startBundle(StartBundleContext c) { @@ -277,19 +315,28 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except KV key = KV.of(window, paneInfo); writer = windowedWriters.get(key); if (writer == null) { - String uuid = UUID.randomUUID().toString(); - LOG.info( - "Opening writer {} for write operation {}, window {} pane {}", - uuid, - writeOperation, - window, - paneInfo); - writer = writeOperation.createWriter(); - writer.openWindowed(uuid, window, paneInfo, UNKNOWN_SHARDNUM); - windowedWriters.put(key, writer); - LOG.debug("Done opening writer"); + if (windowedWriters.size() <= maxNumWritersPerBundle) { + String uuid = UUID.randomUUID().toString(); + LOG.info( + "Opening writer {} for write operation {}, window {} pane {}", + uuid, + writeOperation, + window, + paneInfo); + writer = writeOperation.createWriter(); + writer.openWindowed(uuid, window, paneInfo, UNKNOWN_SHARDNUM); + windowedWriters.put(key, writer); + LOG.debug("Done opening writer"); + } else { + if (spilledShardNum == UNKNOWN_SHARDNUM) { + spilledShardNum = ThreadLocalRandom.current().nextInt(SPILLED_RECORD_SHARDING_FACTOR); + } else { + spilledShardNum = (spilledShardNum + 1) % SPILLED_RECORD_SHARDING_FACTOR; + } + c.output(unwrittedRecordsTag, KV.of(spilledShardNum, c.element())); + return; + } } - writeOrClose(writer, c.element()); } @@ -352,11 +399,17 @@ public void populateDisplayData(DisplayData.Builder builder) { } } + enum ShardAssignment { ASSIGN_IN_FINALIZE, ASSIGN_WHEN_WRITING }; + /** * Like {@link WriteWindowedBundles} and {@link WriteUnwindowedBundles}, but where the elements * for each shard have been collected into a single iterable. */ private class WriteShardedBundles extends DoFn>, FileResult> { + ShardAssignment shardNumberAssignment; + WriteShardedBundles(ShardAssignment shardNumberAssignment) { + this.shardNumberAssignment = shardNumberAssignment; + } @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) throws Exception { // In a sharded write, single input element represents one shard. We can open and close @@ -364,7 +417,9 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except LOG.info("Opening writer for write operation {}", writeOperation); Writer writer = writeOperation.createWriter(); if (windowedWrites) { - writer.openWindowed(UUID.randomUUID().toString(), window, c.pane(), c.element().getKey()); + int shardNumber = shardNumberAssignment == ShardAssignment.ASSIGN_WHEN_WRITING + ? c.element().getKey() : UNKNOWN_SHARDNUM; + writer.openWindowed(UUID.randomUUID().toString(), window, c.pane(), shardNumber); } else { writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM); } @@ -493,14 +548,35 @@ private PDone createWrite(PCollection input) { // initial ParDo. PCollection results; final PCollectionView numShardsView; + @SuppressWarnings("unchecked") Coder shardedWindowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); if (computeNumShards == null && numShardsProvider == null) { numShardsView = null; - results = - input.apply( - "WriteBundles", - ParDo.of(windowedWrites ? 
new WriteWindowedBundles() : new WriteUnwindowedBundles())); + if (windowedWrites) { + TupleTag writtenRecordsTag = new TupleTag<>("writtenRecordsTag"); + TupleTag> unwrittedRecordsTag = new TupleTag<>("unwrittenRecordsTag"); + PCollectionTuple writeTuple = input.apply("WriteWindowedBundles", ParDo.of( + new WriteWindowedBundles(unwrittedRecordsTag)) + .withOutputTags(writtenRecordsTag, TupleTagList.of(unwrittedRecordsTag))); + PCollection writtenBundleFiles = writeTuple.get(writtenRecordsTag) + .setCoder(FileResultCoder.of(shardedWindowCoder)); + // Any "spilled" elements are written using WriteShardedBundles. Assign shard numbers in + // finalize to stay consistent with what WriteWindowedBundles does. + PCollection writtenGroupedFiles = + writeTuple + .get(unwrittedRecordsTag) + .setCoder(KvCoder.of(VarIntCoder.of(), input.getCoder())) + .apply("GroupUnwritten", GroupByKey.create()) + .apply("WriteUnwritten", ParDo.of( + new WriteShardedBundles(ShardAssignment.ASSIGN_IN_FINALIZE))) + .setCoder(FileResultCoder.of(shardedWindowCoder)); + results = PCollectionList.of(writtenBundleFiles).and(writtenGroupedFiles) + .apply(Flatten.pCollections()); + } else { + results = + input.apply("WriteUnwindowedBundles", ParDo.of(new WriteUnwindowedBundles())); + } } else { List> sideInputs = Lists.newArrayList(); if (computeNumShards != null) { @@ -517,10 +593,13 @@ private PDone createWrite(PCollection input) { (numShardsView != null) ? null : numShardsProvider)) .withSideInputs(sideInputs)) .apply("GroupIntoShards", GroupByKey.create()); - shardedWindowCoder = - (Coder) sharded.getWindowingStrategy().getWindowFn().windowCoder(); - - results = sharded.apply("WriteShardedBundles", ParDo.of(new WriteShardedBundles())); + // Since this path might be used by streaming runners processing triggers, it's important + // to assign shard numbers here so that they are deterministic. The ASSIGN_IN_FINALIZE + // strategy works by sorting all FileResult objects and assigning them numbers, which is not + // guaranteed to work well when processing triggers - if the finalize step retries it might + // see a different Iterable of FileResult objects, and it will assign different shard numbers. 
+ results = sharded.apply("WriteShardedBundles", + ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_WHEN_WRITING))); } results.setCoder(FileResultCoder.of(shardedWindowCoder)); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java index 206bc1f343c4b..904f3a2ff837d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java @@ -20,8 +20,10 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import javax.annotation.Nullable; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Hidden; import org.apache.beam.sdk.options.PipelineOptions; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -50,6 +52,14 @@ public interface TestPipelineOptions extends PipelineOptions { Long getTestTimeoutSeconds(); void setTestTimeoutSeconds(Long value); + @Default.Boolean(false) + @Internal + @Hidden + @org.apache.beam.sdk.options.Description( + "Indicates whether this is an automatically-run unit test.") + boolean isUnitTest(); + void setUnitTest(boolean unitTest); + /** * Factory for {@link PipelineResult} matchers which always pass. */ diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java index c97313d397dde..bdf37f635ef5f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java @@ -40,6 +40,10 @@ public SimpleSink(ResourceId baseOutputDirectory, String prefix, String template writableByteChannelFactory); } + public SimpleSink(ResourceId baseOutputDirectory, FilenamePolicy filenamePolicy) { + super(StaticValueProvider.of(baseOutputDirectory), filenamePolicy); + } + @Override public SimpleWriteOperation createWriteOperation() { return new SimpleWriteOperation(this); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java index a5dacd10f6b05..e6a0dcf2c66fa 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java @@ -41,6 +41,7 @@ import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.SimpleSink.SimpleWriter; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; @@ -62,12 +63,15 @@ import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.joda.time.Duration; +import org.joda.time.format.DateTimeFormatter; 
+import org.joda.time.format.ISODateTimeFormat; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -160,7 +164,7 @@ private String getBaseOutputFilename() { public void testWrite() throws IOException { List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", "Intimidating pigeon", "Pedantic gull", "Frisky finch"); - runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename()); + runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); } /** @@ -169,7 +173,8 @@ public void testWrite() throws IOException { @Test @Category(NeedsRunner.class) public void testEmptyWrite() throws IOException { - runWrite(Collections.emptyList(), IDENTITY_MAP, getBaseOutputFilename()); + runWrite(Collections.emptyList(), IDENTITY_MAP, getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink())); checkFileContents(getBaseOutputFilename(), Collections.emptyList(), Optional.of(1)); } @@ -185,7 +190,7 @@ public void testShardedWrite() throws IOException { Arrays.asList("one", "two", "three", "four", "five", "six"), IDENTITY_MAP, getBaseOutputFilename(), - Optional.of(1)); + WriteFiles.to(makeSimpleSink()).withNumShards(1)); } private ResourceId getBaseOutputDirectory() { @@ -194,7 +199,8 @@ private ResourceId getBaseOutputDirectory() { } private SimpleSink makeSimpleSink() { - return new SimpleSink(getBaseOutputDirectory(), "file", "-SS-of-NN", "simple"); + FilenamePolicy filenamePolicy = new PerWindowFiles("file", "simple"); + return new SimpleSink(getBaseOutputDirectory(), filenamePolicy); } @Test @@ -235,7 +241,7 @@ public void testExpandShardedWrite() throws IOException { Arrays.asList("one", "two", "three", "four", "five", "six"), IDENTITY_MAP, getBaseOutputFilename(), - Optional.of(20)); + WriteFiles.to(makeSimpleSink()).withNumShards(20)); } /** @@ -245,7 +251,7 @@ public void testExpandShardedWrite() throws IOException { @Category(NeedsRunner.class) public void testWriteWithEmptyPCollection() throws IOException { List inputs = new ArrayList<>(); - runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename()); + runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); } /** @@ -258,7 +264,7 @@ public void testWriteWindowed() throws IOException { "Intimidating pigeon", "Pedantic gull", "Frisky finch"); runWrite( inputs, new WindowAndReshuffle<>(Window.into(FixedWindows.of(Duration.millis(2)))), - getBaseOutputFilename()); + getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); } /** @@ -274,10 +280,23 @@ public void testWriteWithSessions() throws IOException { inputs, new WindowAndReshuffle<>( Window.into(Sessions.withGapDuration(Duration.millis(1)))), - getBaseOutputFilename()); + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink())); } @Test + @Category(NeedsRunner.class) + public void testWriteSpilling() throws IOException { + List inputs = Lists.newArrayList(); + for (int i = 0; i < 100; ++i) { + inputs.add("mambo_number_" + i); + } + runWrite( + inputs, Window.into(FixedWindows.of(Duration.millis(2))), + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink()).withMaxNumWritersPerBundle(2).withWindowedWrites()); + } + public void testBuildWrite() { SimpleSink sink = makeSimpleSink(); WriteFiles write = WriteFiles.to(sink).withNumShards(3); @@ -365,8 +384,45 @@ public void populateDisplayData(DisplayData.Builder builder) { */ private void runWrite( List inputs, PTransform, PCollection> transform, - String baseName) throws IOException { - runShardedWrite(inputs, transform, baseName, 
Optional.absent()); + String baseName, WriteFiles write) throws IOException { + runShardedWrite(inputs, transform, baseName, write); + } + + private static class PerWindowFiles extends FilenamePolicy { + private static final DateTimeFormatter FORMATTER = ISODateTimeFormat.hourMinuteSecondMillis(); + private final String prefix; + private final String suffix; + + public PerWindowFiles(String prefix, String suffix) { + this.prefix = prefix; + this.suffix = suffix; + } + + public String filenamePrefixForWindow(IntervalWindow window) { + return String.format("%s%s-%s", + prefix, FORMATTER.print(window.start()), FORMATTER.print(window.end())); + } + + @Override + public ResourceId windowedFilename( + ResourceId outputDirectory, WindowedContext context, String extension) { + IntervalWindow window = (IntervalWindow) context.getWindow(); + String filename = String.format( + "%s-%s-of-%s%s%s", + filenamePrefixForWindow(window), context.getShardNumber(), context.getNumShards(), + extension, suffix); + return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + } + + @Override + public ResourceId unwindowedFilename( + ResourceId outputDirectory, Context context, String extension) { + String filename = String.format( + "%s%s-of-%s%s%s", + prefix, context.getShardNumber(), context.getNumShards(), + extension, suffix); + return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + } } /** @@ -379,7 +435,7 @@ private void runShardedWrite( List inputs, PTransform, PCollection> transform, String baseName, - Optional numConfiguredShards) throws IOException { + WriteFiles write) throws IOException { // Flag to validate that the pipeline options are passed to the Sink WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class); options.setTestFlag("test_value"); @@ -390,18 +446,15 @@ private void runShardedWrite( for (long i = 0; i < inputs.size(); i++) { timestamps.add(i + 1); } - - SimpleSink sink = makeSimpleSink(); - WriteFiles write = WriteFiles.to(sink); - if (numConfiguredShards.isPresent()) { - write = write.withNumShards(numConfiguredShards.get()); - } p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of())) .apply(transform) .apply(write); p.run(); - checkFileContents(baseName, inputs, numConfiguredShards); + Optional numShards = + (write.getNumShards() != null) + ? Optional.of(write.getNumShards().get()) : Optional.absent(); + checkFileContents(baseName, inputs, numShards); } static void checkFileContents(String baseName, List inputs, From 4f6032c9c1774a9797e3ff25cc2a05fe56453f21 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Mon, 19 Jun 2017 08:34:31 -0700 Subject: [PATCH 069/200] Bump Dataflow worker to 20170619 --- runners/google-cloud-dataflow-java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index 92c94a8394db6..f627f123de3e5 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -33,7 +33,7 @@ jar - beam-master-20170530 + beam-master-20170619 1 6 From a06c8bfae6fb9e35deeb4adfdd7761889b12be89 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Wed, 1 Feb 2017 17:26:55 -0800 Subject: [PATCH 070/200] [BEAM-1377] Splittable DoFn in Dataflow streaming runner Transform expansion and translation for the involved primitive transforms. 
Of course, the current PR will only work after the respective Dataflow worker and backend changes are released. --- runners/google-cloud-dataflow-java/pom.xml | 6 +- .../dataflow/DataflowPipelineTranslator.java | 40 +++++++++ .../beam/runners/dataflow/DataflowRunner.java | 14 +++ .../dataflow/SplittableParDoOverrides.java | 76 ++++++++++++++++ .../runners/dataflow/util/PropertyNames.java | 1 + .../DataflowPipelineTranslatorTest.java | 89 +++++++++++++++++++ .../sdk/transforms/SplittableDoFnTest.java | 22 ++++- 7 files changed, 246 insertions(+), 2 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index f627f123de3e5..d1bce32b9c5b4 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -216,13 +216,17 @@ validates-runner-tests + org.apache.beam.sdk.testing.LargeKeys$Above10MB, org.apache.beam.sdk.testing.UsesDistributionMetrics, org.apache.beam.sdk.testing.UsesGaugeMetrics, org.apache.beam.sdk.testing.UsesSetState, org.apache.beam.sdk.testing.UsesMapState, - org.apache.beam.sdk.testing.UsesSplittableParDo, + org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs, org.apache.beam.sdk.testing.UsesUnboundedPCollections, org.apache.beam.sdk.testing.UsesTestStream, diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index afc34e6fc8833..bfd9b649add4b 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -56,6 +56,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.core.construction.WindowingStrategyTranslation; import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly; @@ -886,6 +887,45 @@ private void translateHelper(Window.Assign transform, TranslationContext // IO Translation. registerTransformTranslator(Read.Bounded.class, new ReadTranslator()); + + /////////////////////////////////////////////////////////////////////////// + // Splittable DoFn translation. 
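The registration that follows uses the same registry pattern as the translators above: a transform class is mapped to a callback that renders it into a job step. A simplified, generic sketch of that pattern is shown below; the types and names are illustrative only, not Dataflow's actual TranslationContext API.

import java.util.HashMap;
import java.util.Map;

// Simplified sketch of a class-keyed translator registry: look up a callback by the
// concrete transform type and let it emit whatever the backend job format needs.
class SketchTranslatorRegistry {
  interface Translator<T> {
    void translate(T transform);
  }

  private final Map<Class<?>, Translator<?>> translators = new HashMap<>();

  <T> void register(Class<T> transformClass, Translator<T> translator) {
    translators.put(transformClass, translator);
  }

  @SuppressWarnings("unchecked")
  <T> void translate(T transform) {
    Translator<T> translator = (Translator<T>) translators.get(transform.getClass());
    if (translator == null) {
      throw new IllegalStateException("No translator registered for " + transform.getClass());
    }
    translator.translate(transform);
  }
}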
+ + registerTransformTranslator( + SplittableParDo.ProcessKeyedElements.class, + new TransformTranslator() { + @Override + public void translate( + SplittableParDo.ProcessKeyedElements transform, TranslationContext context) { + translateTyped(transform, context); + } + + private void translateTyped( + SplittableParDo.ProcessKeyedElements transform, + TranslationContext context) { + StepTranslationContext stepContext = + context.addStep(transform, "SplittableProcessKeyed"); + + translateInputs( + stepContext, context.getInput(transform), transform.getSideInputs(), context); + BiMap> outputMap = + translateOutputs(context.getOutputs(transform), stepContext); + stepContext.addInput( + PropertyNames.SERIALIZED_FN, + byteArrayToJsonString( + serializeToByteArray( + DoFnInfo.forFn( + transform.getFn(), + transform.getInputWindowingStrategy(), + transform.getSideInputs(), + transform.getElementCoder(), + outputMap.inverse().get(transform.getMainOutputTag()), + outputMap)))); + stepContext.addInput( + PropertyNames.RESTRICTION_CODER, + CloudObjects.asCloudObject(transform.getRestrictionCoder())); + } + }); } private static void translateInputs( diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index ea9db24ff638e..c584b318832c2 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -325,6 +325,20 @@ private List getOverrides(boolean streaming) { new StreamingFnApiCreateOverrideFactory())); } overridesBuilder + // Support Splittable DoFn for now only in streaming mode. + // The order of the following overrides is important because they are applied in order. + + // By default Dataflow runner replaces single-output ParDo with a ParDoSingle override. + // However, we want a different expansion for single-output splittable ParDo. + .add( + PTransformOverride.of( + PTransformMatchers.splittableParDoSingle(), + new ReflectiveOneToOneOverrideFactory( + SplittableParDoOverrides.ParDoSingleViaMulti.class, this))) + .add( + PTransformOverride.of( + PTransformMatchers.splittableParDoMulti(), + new SplittableParDoOverrides.SplittableParDoOverrideFactory())) .add( // Streaming Bounded Read is implemented in terms of Streaming Unbounded Read, and // must precede it diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java new file mode 100644 index 0000000000000..93228782372bc --- /dev/null +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow; + +import java.util.Map; +import org.apache.beam.runners.core.construction.ForwardingPTransform; +import org.apache.beam.runners.core.construction.PTransformReplacements; +import org.apache.beam.runners.core.construction.ReplacementOutputs; +import org.apache.beam.runners.core.construction.SplittableParDo; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; + +/** Transform overrides for supporting {@link SplittableParDo} in the Dataflow runner. */ +class SplittableParDoOverrides { + static class ParDoSingleViaMulti + extends ForwardingPTransform, PCollection> { + private final ParDo.SingleOutput original; + + public ParDoSingleViaMulti( + DataflowRunner ignored, ParDo.SingleOutput original) { + this.original = original; + } + + @Override + protected PTransform, PCollection> delegate() { + return original; + } + + @Override + public PCollection expand(PCollection input) { + TupleTag mainOutput = new TupleTag<>(); + return input.apply(original.withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); + } + } + + static class SplittableParDoOverrideFactory + implements PTransformOverrideFactory< + PCollection, PCollectionTuple, ParDo.MultiOutput> { + @Override + public PTransformReplacement, PCollectionTuple> getReplacementTransform( + AppliedPTransform, PCollectionTuple, ParDo.MultiOutput> + appliedTransform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(appliedTransform), + new SplittableParDo<>(appliedTransform.getTransform())); + } + + @Override + public Map mapOutputs( + Map, PValue> outputs, PCollectionTuple newOutput) { + return ReplacementOutputs.tagged(outputs, newOutput); + } + } +} diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java index f82f1f112c173..55e0c4ebff97b 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java @@ -63,4 +63,5 @@ public class PropertyNames { public static final String USES_KEYED_STATE = "uses_keyed_state"; public static final String VALUE = "value"; public static final String DISPLAY_DATA = "display_data"; + public static final String RESTRICTION_CODER = "restriction_coder"; } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 53215f60f1179..948af1cf606ac 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow; import static org.apache.beam.runners.dataflow.util.Structs.getString; +import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -66,11 +69,15 @@ import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; +import org.apache.beam.runners.dataflow.util.CloudObject; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.util.DoFnInfo; import org.apache.beam.runners.dataflow.util.OutputReference; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.util.Structs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.extensions.gcp.auth.TestCredential; @@ -91,7 +98,13 @@ import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.GcsUtil; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.gcsfs.GcsPath; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -100,6 +113,8 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; +import org.hamcrest.Matchers; +import org.joda.time.Duration; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -896,6 +911,68 @@ public void process(ProcessContext c) { not(equalTo("true"))); } + /** + * Smoke test to fail fast if translation of a splittable ParDo + * in streaming breaks. 
+ */ + @Test + public void testStreamingSplittableParDoTranslation() throws Exception { + DataflowPipelineOptions options = buildPipelineOptions(); + DataflowRunner runner = DataflowRunner.fromOptions(options); + options.setStreaming(true); + DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options); + + Pipeline pipeline = Pipeline.create(options); + + PCollection windowedInput = pipeline + .apply(Create.of("a")) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))); + windowedInput.apply(ParDo.of(new TestSplittableFn())); + + runner.replaceTransforms(pipeline); + + Job job = + translator + .translate( + pipeline, + runner, + Collections.emptyList()) + .getJob(); + + // The job should contain a SplittableParDo.ProcessKeyedElements step, translated as + // "SplittableProcessKeyed". + + List steps = job.getSteps(); + Step processKeyedStep = null; + for (Step step : steps) { + if (step.getKind().equals("SplittableProcessKeyed")) { + assertNull(processKeyedStep); + processKeyedStep = step; + } + } + assertNotNull(processKeyedStep); + + @SuppressWarnings({"unchecked", "rawtypes"}) + DoFnInfo fnInfo = + (DoFnInfo) + SerializableUtils.deserializeFromByteArray( + jsonStringToByteArray( + Structs.getString( + processKeyedStep.getProperties(), PropertyNames.SERIALIZED_FN)), + "DoFnInfo"); + assertThat(fnInfo.getDoFn(), instanceOf(TestSplittableFn.class)); + assertThat( + fnInfo.getWindowingStrategy().getWindowFn(), + Matchers.equalTo(FixedWindows.of(Duration.standardMinutes(1)))); + Coder restrictionCoder = + CloudObjects.coderFromCloudObject( + (CloudObject) + Structs.getObject( + processKeyedStep.getProperties(), PropertyNames.RESTRICTION_CODER)); + + assertEquals(SerializableCoder.of(OffsetRange.class), restrictionCoder); + } + @Test public void testToSingletonTranslationWithIsmSideInput() throws Exception { // A "change detector" test that makes sure the translation @@ -1090,4 +1167,16 @@ private static void assertAllStepOutputsHaveUniqueIds(Job job) assertTrue(String.format("Found duplicate output ids %s", outputIds), outputIds.size() == 0); } + + private static class TestSplittableFn extends DoFn { + @ProcessElement + public void process(ProcessContext c, OffsetRangeTracker tracker) { + // noop + } + + @GetInitialRestriction + public OffsetRange getInitialRange(String element) { + return null; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index 646d8d310bf8a..0c2bd1c871d07 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkState; +import static org.apache.beam.sdk.testing.TestPipeline.testingPipelineOptions; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -33,6 +34,8 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; @@ -97,8 +100,25 
@@ public void process(ProcessContext c) { } } + private static PipelineOptions streamingTestPipelineOptions() { + // Using testing options with streaming=true makes it possible to enable UsesSplittableParDo + // tests in Dataflow runner, because as of writing, it can run Splittable DoFn only in + // streaming mode. + // This is a no-op for other runners currently (Direct runner doesn't care, and other + // runners don't implement SDF at all yet). + // + // This is a workaround until https://issues.apache.org/jira/browse/BEAM-1620 + // is properly implemented and supports marking tests as streaming-only. + // + // https://issues.apache.org/jira/browse/BEAM-2483 specifically tracks the removal of the + // current workaround. + PipelineOptions options = testingPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(true); + return options; + } + @Rule - public final transient TestPipeline p = TestPipeline.create(); + public final transient TestPipeline p = TestPipeline.fromOptions(streamingTestPipelineOptions()); @Test @Category({ValidatesRunner.class, UsesSplittableParDo.class}) From ef19024d2e9dc046c6699aeee1edc483beb9a009 Mon Sep 17 00:00:00 2001 From: Ahmet Altay Date: Tue, 20 Jun 2017 14:25:55 -0700 Subject: [PATCH 071/200] Add a cloud-pubsub dependency to the list of gcp extra packages --- .../apache_beam/examples/streaming_wordcount.py | 13 ++++++++----- sdks/python/setup.py | 1 + 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/examples/streaming_wordcount.py b/sdks/python/apache_beam/examples/streaming_wordcount.py index ed8b5d08dc62e..f2b179aa2438d 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount.py @@ -25,16 +25,19 @@ import argparse import logging -import re import apache_beam as beam import apache_beam.transforms.window as window +def split_fn(lines): + import re + return re.findall(r'[A-Za-z\']+', x) + + def run(argv=None): """Build and run the pipeline.""" - parser = argparse.ArgumentParser() parser.add_argument( '--input_topic', required=True, @@ -46,14 +49,14 @@ def run(argv=None): with beam.Pipeline(argv=pipeline_args) as p: - # Read the text file[pattern] into a PCollection. + # Read from PubSub into a PCollection. lines = p | beam.io.ReadStringsFromPubSub(known_args.input_topic) # Capitalize the characters in each line. transformed = (lines + # Use a pre-defined function that imports the re package. | 'Split' >> ( - beam.FlatMap(lambda x: re.findall(r'[A-Za-z\']+', x)) - .with_output_types(unicode)) + beam.FlatMap(split_fn).with_output_types(unicode)) | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) | beam.WindowInto(window.FixedWindows(15, 0)) | 'Group' >> beam.GroupByKey() diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 051043b07e7f5..584c852c57b82 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -118,6 +118,7 @@ def get_version(): 'google-apitools>=0.5.10,<=0.5.11', 'proto-google-cloud-datastore-v1>=0.90.0,<=0.90.4', 'googledatastore==7.0.1', + 'google-cloud-pubsub==0.25.0', # GCP packages required by tests 'google-cloud-bigquery>=0.23.0,<0.25.0', ] From cbb922c8a72680c5b8b4299197b515abf650bfdf Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Wed, 8 Feb 2017 12:53:27 -0800 Subject: [PATCH 072/200] BEAM-1438 Auto shard streaming sinks If a Write operation in streaming requests runner-determined sharding, make the Dataflow runner default to maxNumWorkers * 2 shards. 
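The sizing policy described above reduces to a small fallback chain, shown in isolation below. The constants mirror the StreamingShardedWriteFactory added in the diff; the wrapper class and main method are only for illustration.

// Sketch of the shard-count fallback for runner-determined sharding in streaming:
// prefer maxNumWorkers * 2, then numWorkers * 2, otherwise a fixed default.
public final class StreamingShardPolicySketch {
  static final int DEFAULT_NUM_SHARDS = 10;

  static int pickNumShards(int maxNumWorkers, int numWorkers) {
    if (maxNumWorkers > 0) {
      return maxNumWorkers * 2;
    }
    if (numWorkers > 0) {
      return numWorkers * 2;
    }
    return DEFAULT_NUM_SHARDS;
  }

  public static void main(String[] args) {
    System.out.println(pickNumShards(25, 4)); // 50: maxNumWorkers takes precedence
    System.out.println(pickNumShards(0, 4));  // 8: numWorkers * 2
    System.out.println(pickNumShards(0, 0));  // 10: nothing set, fixed default
  }
}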
--- .../beam/runners/dataflow/DataflowRunner.java | 57 ++++++++++++++++++ .../runners/dataflow/DataflowRunnerTest.java | 60 ++++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index c584b318832c2..1741287d77036 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -67,10 +67,12 @@ import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource; import org.apache.beam.runners.core.construction.UnconsumedReads; +import org.apache.beam.runners.core.construction.WriteFilesTranslation; import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification; import org.apache.beam.runners.dataflow.StreamingViewOverrides.StreamingCreatePCollectionViewFactory; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; import org.apache.beam.runners.dataflow.util.DataflowTemplateJob; import org.apache.beam.runners.dataflow.util.DataflowTransport; import org.apache.beam.runners.dataflow.util.MonitoringUtil; @@ -91,6 +93,7 @@ import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageWithAttributesCoder; @@ -339,6 +342,10 @@ private List getOverrides(boolean streaming) { PTransformOverride.of( PTransformMatchers.splittableParDoMulti(), new SplittableParDoOverrides.SplittableParDoOverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.writeWithRunnerDeterminedSharding(), + new StreamingShardedWriteFactory(options))) .add( // Streaming Bounded Read is implemented in terms of Streaming Unbounded Read, and // must precede it @@ -1442,6 +1449,56 @@ public Map mapOutputs( } } + @VisibleForTesting + static class StreamingShardedWriteFactory + implements PTransformOverrideFactory, PDone, WriteFiles> { + // We pick 10 as a a default, as it works well with the default number of workers started + // by Dataflow. + static final int DEFAULT_NUM_SHARDS = 10; + DataflowPipelineWorkerPoolOptions options; + + StreamingShardedWriteFactory(PipelineOptions options) { + this.options = options.as(DataflowPipelineWorkerPoolOptions.class); + } + + @Override + public PTransformReplacement, PDone> getReplacementTransform( + AppliedPTransform, PDone, WriteFiles> transform) { + // By default, if numShards is not set WriteFiles will produce one file per bundle. In + // streaming, there are large numbers of small bundles, resulting in many tiny files. + // Instead we pick max workers * 2 to ensure full parallelism, but prevent too-many files. + // (current_num_workers * 2 might be a better choice, but that value is not easily available + // today). + // If the user does not set either numWorkers or maxNumWorkers, default to 10 shards. 
+ int numShards; + if (options.getMaxNumWorkers() > 0) { + numShards = options.getMaxNumWorkers() * 2; + } else if (options.getNumWorkers() > 0) { + numShards = options.getNumWorkers() * 2; + } else { + numShards = DEFAULT_NUM_SHARDS; + } + + try { + WriteFiles replacement = WriteFiles.to(WriteFilesTranslation.getSink(transform)); + if (WriteFilesTranslation.isWindowedWrites(transform)) { + replacement = replacement.withWindowedWrites(); + } + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + replacement.withNumShards(numShards)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Map mapOutputs(Map, PValue> outputs, + PDone newOutput) { + return Collections.emptyMap(); + } + } + @VisibleForTesting static String getContainerImageForJob(DataflowPipelineOptions options) { String workerHarnessContainerImage = options.getWorkerHarnessContainerImage(); diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 8f10b18d3eda3..aae21cffd4c00 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -62,21 +63,28 @@ import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Pattern; +import org.apache.beam.runners.dataflow.DataflowRunner.StreamingShardedWriteFactory; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.extensions.gcp.auth.NoopCredentialFactory; import org.apache.beam.sdk.extensions.gcp.auth.TestCredential; import org.apache.beam.sdk.extensions.gcp.storage.NoopPathValidator; +import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.testing.ExpectedLogs; @@ -87,7 +95,10 @@ import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.util.gcsfs.GcsPath; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import 
org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.hamcrest.Description; import org.hamcrest.Matchers; @@ -823,7 +834,6 @@ public void testTempLocationAndNoGcpTempLocationSucceeds() throws Exception { DataflowRunner.fromOptions(options); } - @Test public void testValidProfileLocation() throws IOException { DataflowPipelineOptions options = buildPipelineOptions(); @@ -1047,8 +1057,8 @@ public void testToString() { } /** - * Tests that the {@link DataflowRunner} with {@code --templateLocation} returns normally - * when the runner issuccessfully run. + * Tests that the {@link DataflowRunner} with {@code --templateLocation} returns normally when the + * runner is successfully run. */ @Test public void testTemplateRunnerFullCompletion() throws Exception { @@ -1127,4 +1137,48 @@ public void testWorkerHarnessContainerImage() { assertThat( getContainerImageForJob(options), equalTo("gcr.io/java/foo")); } + + @Test + public void testStreamingWriteWithNoShardingReturnsNewTransform() { + PipelineOptions options = TestPipeline.testingPipelineOptions(); + options.as(DataflowPipelineWorkerPoolOptions.class).setMaxNumWorkers(10); + testStreamingWriteOverride(options, 20); + } + + @Test + public void testStreamingWriteWithNoShardingReturnsNewTransformMaxWorkersUnset() { + PipelineOptions options = TestPipeline.testingPipelineOptions(); + testStreamingWriteOverride(options, StreamingShardedWriteFactory.DEFAULT_NUM_SHARDS); + } + + private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) { + TestPipeline p = TestPipeline.fromOptions(options); + + StreamingShardedWriteFactory factory = + new StreamingShardedWriteFactory<>(p.getOptions()); + WriteFiles original = WriteFiles.to(new TestSink(tmpFolder.toString())); + PCollection objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); + AppliedPTransform, PDone, WriteFiles> originalApplication = + AppliedPTransform.of( + "writefiles", objs.expand(), Collections., PValue>emptyMap(), original, p); + + WriteFiles replacement = (WriteFiles) + factory.getReplacementTransform(originalApplication).getTransform(); + assertThat(replacement, not(equalTo((Object) original))); + assertThat(replacement.getNumShards().get(), equalTo(expectedNumShards)); + } + + private static class TestSink extends FileBasedSink { + @Override + public void validate(PipelineOptions options) {} + + TestSink(String tmpFolder) { + super(StaticValueProvider.of(FileSystems.matchNewResource(tmpFolder, true)), + null); + } + @Override + public WriteOperation createWriteOperation() { + throw new IllegalArgumentException("Should not be used"); + } + } } From 5a95d620ad0c7f427a5b849059a0215c3f061a58 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Tue, 20 Jun 2017 18:53:50 -0700 Subject: [PATCH 073/200] None should be a valid return element --- sdks/python/apache_beam/typehints/trivial_inference.py | 3 +-- .../python/apache_beam/typehints/trivial_inference_test.py | 7 +++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 977ea066a9f4c..c7405963f1614 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -40,8 +40,7 @@ def instance_to_type(o): """ t = type(o) if o is None: - # TODO(robertwb): Eliminate inconsistent use 
of None vs. NoneType. - return None + return type(None) elif t not in typehints.DISALLOWED_PRIMITIVE_TYPES: if t == types.InstanceType: return o.__class__ diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index ac00baa3f4c73..e7f451da11a6f 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -60,6 +60,13 @@ def reverse((a, b)): self.assertReturnType(any_tuple, reverse, [trivial_inference.Const((1, 2, 3))]) + def testNoneReturn(self): + def func(a): + if a == 5: + return a + return None + self.assertReturnType(typehints.Union[int, type(None)], func, [int]) + def testListComprehension(self): self.assertReturnType( typehints.List[int], From 6b6d20d9dc5afa0c1d8520cf6dbc98e6488a58a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Wed, 21 Jun 2017 01:04:18 +0200 Subject: [PATCH 074/200] Return a valid Coder for any subtype of Mutation on HBaseCoderProviderRegistrar --- .../io/hbase/HBaseCoderProviderRegistrar.java | 11 +---- .../beam/sdk/io/hbase/HBaseMutationCoder.java | 42 +++++++++++++++++++ .../HBaseCoderProviderRegistrarTest.java | 4 ++ 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java index dee3c703addef..2973d1b2dc23d 100644 --- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrar.java @@ -24,11 +24,6 @@ import org.apache.beam.sdk.coders.CoderProviderRegistrar; import org.apache.beam.sdk.coders.CoderProviders; import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.hadoop.hbase.client.Append; -import org.apache.hadoop.hbase.client.Delete; -import org.apache.hadoop.hbase.client.Increment; -import org.apache.hadoop.hbase.client.Mutation; -import org.apache.hadoop.hbase.client.Put; import org.apache.hadoop.hbase.client.Result; /** @@ -39,11 +34,7 @@ public class HBaseCoderProviderRegistrar implements CoderProviderRegistrar { @Override public List getCoderProviders() { return ImmutableList.of( - CoderProviders.forCoder(TypeDescriptor.of(Append.class), HBaseMutationCoder.of()), - CoderProviders.forCoder(TypeDescriptor.of(Delete.class), HBaseMutationCoder.of()), - CoderProviders.forCoder(TypeDescriptor.of(Increment.class), HBaseMutationCoder.of()), - CoderProviders.forCoder(TypeDescriptor.of(Mutation.class), HBaseMutationCoder.of()), - CoderProviders.forCoder(TypeDescriptor.of(Put.class), HBaseMutationCoder.of()), + HBaseMutationCoder.getCoderProvider(), CoderProviders.forCoder(TypeDescriptor.of(Result.class), HBaseResultCoder.of())); } } diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseMutationCoder.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseMutationCoder.java index 501fe09259c81..ee83114d3b41b 100644 --- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseMutationCoder.java +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseMutationCoder.java @@ -21,8 +21,12 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; +import java.util.List; import org.apache.beam.sdk.coders.AtomicCoder; +import 
org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderProvider; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.hadoop.hbase.client.Delete; import org.apache.hadoop.hbase.client.Mutation; import org.apache.hadoop.hbase.client.Put; @@ -65,4 +69,42 @@ private static MutationType getType(Mutation mutation) { throw new IllegalArgumentException("Only Put and Delete are supported"); } } + + /** + * Returns a {@link CoderProvider} which uses the {@link HBaseMutationCoder} for + * {@link Mutation mutations}. + */ + static CoderProvider getCoderProvider() { + return HBASE_MUTATION_CODER_PROVIDER; + } + + private static final CoderProvider HBASE_MUTATION_CODER_PROVIDER = + new HBaseMutationCoderProvider(); + + /** + * A {@link CoderProvider} for {@link Mutation mutations}. + */ + private static class HBaseMutationCoderProvider extends CoderProvider { + @Override + public Coder coderFor(TypeDescriptor typeDescriptor, + List> componentCoders) throws CannotProvideCoderException { + if (!typeDescriptor.isSubtypeOf(HBASE_MUTATION_TYPE_DESCRIPTOR)) { + throw new CannotProvideCoderException( + String.format( + "Cannot provide %s because %s is not a subclass of %s", + HBaseMutationCoder.class.getSimpleName(), + typeDescriptor, + Mutation.class.getName())); + } + + try { + return (Coder) HBaseMutationCoder.of(); + } catch (IllegalArgumentException e) { + throw new CannotProvideCoderException(e); + } + } + } + + private static final TypeDescriptor HBASE_MUTATION_TYPE_DESCRIPTOR = + new TypeDescriptor() {}; } diff --git a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java index ac81e8a7ad5fa..5b2e13861bf64 100644 --- a/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java +++ b/sdks/java/io/hbase/src/test/java/org/apache/beam/sdk/io/hbase/HBaseCoderProviderRegistrarTest.java @@ -18,7 +18,9 @@ package org.apache.beam.sdk.io.hbase; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.hadoop.hbase.client.Delete; import org.apache.hadoop.hbase.client.Mutation; +import org.apache.hadoop.hbase.client.Put; import org.apache.hadoop.hbase.client.Result; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,5 +39,7 @@ public void testResultCoderIsRegistered() throws Exception { @Test public void testMutationCoderIsRegistered() throws Exception { CoderRegistry.createDefault().getCoder(Mutation.class); + CoderRegistry.createDefault().getCoder(Put.class); + CoderRegistry.createDefault().getCoder(Delete.class); } } From 6681472a2aa277ba83fd9e2ffec5a57c46d5820c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 20 Jun 2017 14:13:00 +0200 Subject: [PATCH 075/200] [BEAM-2481] Update commons-lang3 dependency to version 3.6 --- pom.xml | 9 ++++++++- runners/apex/pom.xml | 2 +- runners/spark/pom.xml | 5 +++++ .../beam/runners/spark/SparkNativePipelineVisitor.java | 3 +-- sdks/java/io/google-cloud-platform/pom.xml | 3 ++- .../apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java | 9 +++++---- 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/pom.xml b/pom.xml index 9373a40b19dd5..98cace95808ab 100644 --- a/pom.xml +++ b/pom.xml @@ -101,8 +101,9 @@ - 3.5 1.9 + 3.6 + 1.1 2.24.0 1.0.0-rc2 1.8.2 @@ -570,6 +571,12 @@ ${apache.commons.lang.version} + + org.apache.commons + 
commons-text + ${apache.commons.text.version} + + io.grpc grpc-all diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml index d3d4318d2dcf4..2c5465499995a 100644 --- a/runners/apex/pom.xml +++ b/runners/apex/pom.xml @@ -256,7 +256,7 @@ org.apache.apex:apex-api:jar:${apex.core.version} - org.apache.commons:commons-lang3::3.1 + org.apache.commons:commons-lang3::${apache.commons.lang.version} commons-io:commons-io:jar:2.4 com.esotericsoftware.kryo:kryo::${apex.kryo.version} com.datatorrent:netlet::1.3.0 diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml index d1dba323b94ff..0f6b73091688c 100644 --- a/runners/spark/pom.xml +++ b/runners/spark/pom.xml @@ -195,6 +195,11 @@ org.apache.commons commons-lang3 + provided + + + org.apache.commons + commons-text commons-io diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java index d75c955960e64..6972acb16246c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java @@ -35,8 +35,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.commons.lang3.text.WordUtils; - +import org.apache.commons.text.WordUtils; /** * Pipeline visitor for translating a Beam pipeline into equivalent Spark operations. diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml index 8b5382092ec64..6737eea5b256a 100644 --- a/sdks/java/io/google-cloud-platform/pom.xml +++ b/sdks/java/io/google-cloud-platform/pom.xml @@ -255,7 +255,8 @@ org.apache.commons - commons-lang3 + commons-text + test diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index 8df224b76a08b..e1f6582749a4f 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -42,7 +42,8 @@ import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.text.RandomStringGenerator; + import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -115,8 +116,8 @@ public void setUp() throws Exception { } private String generateDatabaseName() { - String random = RandomStringUtils - .randomAlphanumeric(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()) + String random = new RandomStringGenerator.Builder().build() + .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()) .toLowerCase(); return options.getDatabaseIdPrefix() + "-" + random; } @@ -165,7 +166,7 @@ public void processElement(ProcessContext c) { Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table); Long key = c.element(); builder.set("Key").to(key); - builder.set("Value").to(RandomStringUtils.randomAlphabetic(valueSize)); + builder.set("Value").to(new RandomStringGenerator.Builder().build().generate(valueSize)); Mutation mutation = builder.build(); c.output(mutation); } From 
56041b7850abfbb10d4a6ff2ddecb227a0a4e7c8 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Tue, 20 Jun 2017 15:22:58 -0700 Subject: [PATCH 076/200] Use state / timer API for DirectRunner timer firings --- .../runners/direct/evaluation_context.py | 3 +- .../apache_beam/runners/direct/executor.py | 37 +++++++----- .../runners/direct/transform_evaluator.py | 48 +++++++++++++--- .../direct/{transform_result.py => util.py} | 34 ++++++++--- .../runners/direct/watermark_manager.py | 56 +++++++++++-------- sdks/python/apache_beam/transforms/trigger.py | 10 +++- 6 files changed, 131 insertions(+), 57 deletions(-) rename sdks/python/apache_beam/runners/direct/{transform_result.py => util.py} (61%) diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 8fa8e06922d03..976e9e8c8e958 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -148,7 +148,8 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._transform_keyed_states = self._initialize_keyed_states( root_transforms, value_to_consumers) self._watermark_manager = WatermarkManager( - Clock(), root_transforms, value_to_consumers) + Clock(), root_transforms, value_to_consumers, + self._transform_keyed_states) self._side_inputs_container = _SideInputsContainer(views) self._pending_unblocked_tasks = [] self._counter_factory = counters.CounterFactory() diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index eff2d3c41e66d..a0a3886f733c8 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -222,14 +222,14 @@ class _CompletionCallback(object): or for a source transform. 
""" - def __init__(self, evaluation_context, all_updates, timers=None): + def __init__(self, evaluation_context, all_updates, timer_firings=None): self._evaluation_context = evaluation_context self._all_updates = all_updates - self._timers = timers + self._timer_firings = timer_firings or [] def handle_result(self, input_committed_bundle, transform_result): output_committed_bundles = self._evaluation_context.handle_result( - input_committed_bundle, self._timers, transform_result) + input_committed_bundle, self._timer_firings, transform_result) for output_committed_bundle in output_committed_bundles: self._all_updates.offer(_ExecutorServiceParallelExecutor._ExecutorUpdate( output_committed_bundle, None)) @@ -251,11 +251,12 @@ class TransformExecutor(_ExecutorService.CallableTask): """ def __init__(self, transform_evaluator_registry, evaluation_context, - input_bundle, applied_ptransform, completion_callback, - transform_evaluation_state): + input_bundle, fired_timers, applied_ptransform, + completion_callback, transform_evaluation_state): self._transform_evaluator_registry = transform_evaluator_registry self._evaluation_context = evaluation_context self._input_bundle = input_bundle + self._fired_timers = fired_timers self._applied_ptransform = applied_ptransform self._completion_callback = completion_callback self._transform_evaluation_state = transform_evaluation_state @@ -288,6 +289,10 @@ def call(self): self._applied_ptransform, self._input_bundle, side_input_values, scoped_metrics_container) + if self._fired_timers: + for timer_firing in self._fired_timers: + evaluator.process_timer_wrapper(timer_firing) + if self._input_bundle: for value in self._input_bundle.get_elements_iterable(): evaluator.process_element(value) @@ -379,11 +384,11 @@ def schedule_consumers(self, committed_bundle): if committed_bundle.pcollection in self.value_to_consumers: consumers = self.value_to_consumers[committed_bundle.pcollection] for applied_ptransform in consumers: - self.schedule_consumption(applied_ptransform, committed_bundle, + self.schedule_consumption(applied_ptransform, committed_bundle, [], self.default_completion_callback) def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, - on_complete): + fired_timers, on_complete): """Schedules evaluation of the given bundle with the transform.""" assert consumer_applied_ptransform assert committed_bundle @@ -397,8 +402,8 @@ def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, transform_executor = TransformExecutor( self.transform_evaluator_registry, self.evaluation_context, - committed_bundle, consumer_applied_ptransform, on_complete, - transform_executor_service) + committed_bundle, fired_timers, consumer_applied_ptransform, + on_complete, transform_executor_service) transform_executor_service.schedule(transform_executor) class _TypedUpdateQueue(object): @@ -527,19 +532,21 @@ def _fire_timers(self): Returns: True if timers fired. """ - fired_timers = self._executor.evaluation_context.extract_fired_timers() - for applied_ptransform in fired_timers: + transform_fired_timers = ( + self._executor.evaluation_context.extract_fired_timers()) + for applied_ptransform, fired_timers in transform_fired_timers: # Use an empty committed bundle. just to trigger. 
empty_bundle = ( self._executor.evaluation_context.create_empty_committed_bundle( applied_ptransform.inputs[0])) timer_completion_callback = _CompletionCallback( self._executor.evaluation_context, self._executor.all_updates, - applied_ptransform) + timer_firings=fired_timers) self._executor.schedule_consumption( - applied_ptransform, empty_bundle, timer_completion_callback) - return bool(fired_timers) + applied_ptransform, empty_bundle, fired_timers, + timer_completion_callback) + return bool(transform_fired_timers) def _is_executing(self): """Returns True if there is at least one non-blocked TransformExecutor.""" @@ -582,6 +589,6 @@ def _add_work_if_necessary(self, timers_fired): applied_ptransform, []) for bundle in pending_bundles: self._executor.schedule_consumption( - applied_ptransform, bundle, + applied_ptransform, bundle, [], self._executor.default_completion_callback) self._executor.node_to_pending_bundles[applied_ptransform] = [] diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 6e73561d3fe08..e92d799e3ed35 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -28,13 +28,15 @@ from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState from apache_beam.runners.direct.watermark_manager import WatermarkManager -from apache_beam.runners.direct.transform_result import TransformResult +from apache_beam.runners.direct.util import KeyedWorkItem +from apache_beam.runners.direct.util import TransformResult from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue from apache_beam.transforms.trigger import _CombiningValueStateTag from apache_beam.transforms.trigger import _ListStateTag +from apache_beam.transforms.trigger import TimeDomain from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn @@ -199,6 +201,25 @@ def start_bundle(self): """Starts a new bundle.""" pass + def process_timer_wrapper(self, timer_firing): + """Process timer by clearing and then calling process_timer(). + + This method is called with any timer firing and clears the delivered + timer from the keyed state and then calls process_timer(). The default + process_timer() implementation emits a KeyedWorkItem for the particular + timer and passes it to process_element(). Evaluator subclasses which + desire different timer delivery semantics can override process_timer(). + """ + state = self.step_context.get_keyed_state(timer_firing.key) + state.clear_timer( + timer_firing.window, timer_firing.name, timer_firing.time_domain) + self.process_timer(timer_firing) + + def process_timer(self, timer_firing): + """Default process_timer() impl. 
generating KeyedWorkItem element.""" + self.process_element( + KeyedWorkItem(timer_firing.key, timer_firing=timer_firing)) + def process_element(self, element): """Processes a new element as part of the current bundle.""" raise NotImplementedError('%s do not process elements.', type(self)) @@ -244,7 +265,7 @@ def _read_values_to_bundles(reader): bundles = _read_values_to_bundles(reader) return TransformResult( - self._applied_ptransform, bundles, None, None, None) + self._applied_ptransform, bundles, None, None) class _FlattenEvaluator(_TransformEvaluator): @@ -268,7 +289,7 @@ def process_element(self, element): def finish_bundle(self): bundles = [self.bundle] return TransformResult( - self._applied_ptransform, bundles, None, None, None) + self._applied_ptransform, bundles, None, None) class _TaggedReceivers(dict): @@ -357,7 +378,7 @@ def finish_bundle(self): bundles = self._tagged_receivers.values() result_counters = self._counter_factory.get_counters() return TransformResult( - self._applied_ptransform, bundles, None, result_counters, None, + self._applied_ptransform, bundles, result_counters, None, self._tagged_receivers.undeclared_in_memory_tag_values) @@ -375,7 +396,6 @@ def __init__(self, evaluation_context, applied_ptransform, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs, scoped_metrics_container) - @property def _is_final_bundle(self): return (self._execution_context.watermarks.input_watermark == WatermarkManager.WATERMARK_POS_INF) @@ -392,6 +412,10 @@ def start_bundle(self): self._applied_ptransform.transform.get_type_hints().input_types[0]) self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) + def process_timer(self, timer_firing): + # We do not need to emit a KeyedWorkItem to process_element(). + pass + def process_element(self, element): assert not self.global_state.get_state( None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG) @@ -408,7 +432,7 @@ def process_element(self, element): % element) def finish_bundle(self): - if self._is_final_bundle: + if self._is_final_bundle(): if self.global_state.get_state( None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG): # Ignore empty bundles after emitting output. (This may happen because @@ -441,9 +465,11 @@ def len_element_fn(element): else: bundles = [] hold = WatermarkManager.WATERMARK_NEG_INF + self.global_state.set_timer( + None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, bundles, None, None, hold) + self._applied_ptransform, bundles, None, hold) class _NativeWriteEvaluator(_TransformEvaluator): @@ -475,6 +501,10 @@ def start_bundle(self): self.step_context = self._execution_context.get_step_context() self.global_state = self.step_context.get_keyed_state(None) + def process_timer(self, timer_firing): + # We do not need to emit a KeyedWorkItem to process_element(). 
+ pass + def process_element(self, element): self.global_state.add_state( None, _NativeWriteEvaluator.ELEMENTS_TAG, element) @@ -500,6 +530,8 @@ def finish_bundle(self): hold = WatermarkManager.WATERMARK_POS_INF else: hold = WatermarkManager.WATERMARK_NEG_INF + self.global_state.set_timer( + None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, [], None, None, hold) + self._applied_ptransform, [], None, hold) diff --git a/sdks/python/apache_beam/runners/direct/transform_result.py b/sdks/python/apache_beam/runners/direct/util.py similarity index 61% rename from sdks/python/apache_beam/runners/direct/transform_result.py rename to sdks/python/apache_beam/runners/direct/util.py index 51593e3a434ba..daaaceb4738f7 100644 --- a/sdks/python/apache_beam/runners/direct/transform_result.py +++ b/sdks/python/apache_beam/runners/direct/util.py @@ -15,26 +15,44 @@ # limitations under the License. # -"""The result of evaluating an AppliedPTransform with a TransformEvaluator.""" +"""Utility classes used by the DirectRunner. + +For internal use only. No backwards compatibility guarantees. +""" from __future__ import absolute_import class TransformResult(object): - """For internal use only; no backwards-compatibility guarantees. - - The result of evaluating an AppliedPTransform with a TransformEvaluator.""" + """Result of evaluating an AppliedPTransform with a TransformEvaluator.""" def __init__(self, applied_ptransform, uncommitted_output_bundles, - timer_update, counters, watermark_hold, - undeclared_tag_values=None): + counters, watermark_hold, undeclared_tag_values=None): self.transform = applied_ptransform self.uncommitted_output_bundles = uncommitted_output_bundles - # TODO: timer update is currently unused. - self.timer_update = timer_update self.counters = counters self.watermark_hold = watermark_hold # Only used when caching (materializing) all values is requested. self.undeclared_tag_values = undeclared_tag_values # Populated by the TransformExecutor. 
self.logical_metric_updates = None + + +class TimerFiring(object): + """A single instance of a fired timer.""" + + def __init__(self, key, window, name, time_domain, timestamp): + self.key = key + self.window = window + self.name = name + self.time_domain = time_domain + self.timestamp = timestamp + + +class KeyedWorkItem(object): + """A keyed item that can either be a timer firing or a list of elements.""" + def __init__(self, key, timer_firing=None, elements=None): + self.key = key + assert not timer_firing and elements + self.timer_firing = timer_firing + self.elements = elements diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 0d7cd4fd79c00..10d25d7f07aa0 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -23,6 +23,7 @@ from apache_beam import pipeline from apache_beam import pvalue +from apache_beam.runners.direct.util import TimerFiring from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import TIME_GRANULARITY @@ -36,21 +37,23 @@ class WatermarkManager(object): WATERMARK_POS_INF = MAX_TIMESTAMP WATERMARK_NEG_INF = MIN_TIMESTAMP - def __init__(self, clock, root_transforms, value_to_consumers): + def __init__(self, clock, root_transforms, value_to_consumers, + transform_keyed_states): self._clock = clock # processing time clock - self._value_to_consumers = value_to_consumers self._root_transforms = root_transforms + self._value_to_consumers = value_to_consumers + self._transform_keyed_states = transform_keyed_states # AppliedPTransform -> TransformWatermarks self._transform_to_watermarks = {} for root_transform in root_transforms: self._transform_to_watermarks[root_transform] = _TransformWatermarks( - self._clock) + self._clock, transform_keyed_states[root_transform], root_transform) for consumers in value_to_consumers.values(): for consumer in consumers: self._transform_to_watermarks[consumer] = _TransformWatermarks( - self._clock) + self._clock, transform_keyed_states[consumer], consumer) for consumers in value_to_consumers.values(): for consumer in consumers: @@ -90,16 +93,17 @@ def get_watermarks(self, applied_ptransform): return self._transform_to_watermarks[applied_ptransform] def update_watermarks(self, completed_committed_bundle, applied_ptransform, - timer_update, outputs, earliest_hold): + completed_timers, outputs, earliest_hold): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( - completed_committed_bundle, applied_ptransform, timer_update, outputs) + completed_committed_bundle, applied_ptransform, completed_timers, + outputs) tw = self.get_watermarks(applied_ptransform) tw.hold(earliest_hold) self._refresh_watermarks(applied_ptransform) def _update_pending(self, input_committed_bundle, applied_ptransform, - timer_update, output_committed_bundles): + completed_timers, output_committed_bundles): """Updated list of pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. 
They do not impact @@ -113,7 +117,7 @@ def _update_pending(self, input_committed_bundle, applied_ptransform, consumer_tw.add_pending(output) completed_tw = self._transform_to_watermarks[applied_ptransform] - completed_tw.update_timers(timer_update) + completed_tw.update_timers(completed_timers) assert input_committed_bundle or applied_ptransform in self._root_transforms if input_committed_bundle and input_committed_bundle.has_elements(): @@ -137,33 +141,37 @@ def _refresh_watermarks(self, applied_ptransform): def extract_fired_timers(self): all_timers = [] for applied_ptransform, tw in self._transform_to_watermarks.iteritems(): - if tw.extract_fired_timers(): - all_timers.append(applied_ptransform) + fired_timers = tw.extract_fired_timers() + if fired_timers: + all_timers.append((applied_ptransform, fired_timers)) return all_timers class _TransformWatermarks(object): - """Tracks input and output watermarks for aan AppliedPTransform.""" + """Tracks input and output watermarks for an AppliedPTransform.""" - def __init__(self, clock): + def __init__(self, clock, keyed_states, transform): self._clock = clock + self._keyed_states = keyed_states self._input_transform_watermarks = [] self._input_watermark = WatermarkManager.WATERMARK_NEG_INF self._output_watermark = WatermarkManager.WATERMARK_NEG_INF self._earliest_hold = WatermarkManager.WATERMARK_POS_INF self._pending = set() # Scheduled bundles targeted for this transform. - self._fired_timers = False + self._fired_timers = set() self._lock = threading.Lock() + self._label = str(transform) + def update_input_transform_watermarks(self, input_transform_watermarks): with self._lock: self._input_transform_watermarks = input_transform_watermarks - def update_timers(self, timer_update): + def update_timers(self, completed_timers): with self._lock: - if timer_update: - assert self._fired_timers - self._fired_timers = False + for timer_firing in completed_timers: + print 'REMOVE', timer_firing + self._fired_timers.remove(timer_firing) @property def input_watermark(self): @@ -233,8 +241,12 @@ def extract_fired_timers(self): if self._fired_timers: return False - should_fire = ( - self._earliest_hold < WatermarkManager.WATERMARK_POS_INF and - self._input_watermark == WatermarkManager.WATERMARK_POS_INF) - self._fired_timers = should_fire - return should_fire + fired_timers = [] + for key, state in self._keyed_states.iteritems(): + timers = state.get_timers(watermark=self._input_watermark) + for expired in timers: + window, (name, time_domain, timestamp) = expired + fired_timers.append( + TimerFiring(key, window, name, time_domain, timestamp)) + self._fired_timers.update(fired_timers) + return fired_timers diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 89c6ec535db9d..7ff44fa8fde39 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -1102,17 +1102,21 @@ def clear_state(self, window, tag): if not self.state[window]: self.state.pop(window, None) - def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): + def get_timers(self, clear=False, watermark=MAX_TIMESTAMP): expired = [] for window, timers in list(self.timers.items()): for (name, time_domain), timestamp in list(timers.items()): if timestamp <= watermark: expired.append((window, (name, time_domain, timestamp))) - del timers[(name, time_domain)] - if not timers: + if clear: + del timers[(name, time_domain)] + if not timers and clear: del self.timers[window] return expired + def 
get_and_clear_timers(self, watermark=MAX_TIMESTAMP): + return self.get_timers(clear=True, watermark=watermark) + def __repr__(self): state_str = '\n'.join('%s: %s' % (key, dict(state)) for key, state in self.state.items()) From 65a6d66251b081d540de85ed55dce3b62de797e7 Mon Sep 17 00:00:00 2001 From: jasonkuster Date: Wed, 21 Jun 2017 09:40:51 -0700 Subject: [PATCH 077/200] Turn notifications for broken Windows test off. --- .../job_beam_PostCommit_Java_MavenInstall_Windows.groovy | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.test-infra/jenkins/job_beam_PostCommit_Java_MavenInstall_Windows.groovy b/.test-infra/jenkins/job_beam_PostCommit_Java_MavenInstall_Windows.groovy index f781b4eec033b..6ef272cfbe8d1 100644 --- a/.test-infra/jenkins/job_beam_PostCommit_Java_MavenInstall_Windows.groovy +++ b/.test-infra/jenkins/job_beam_PostCommit_Java_MavenInstall_Windows.groovy @@ -32,7 +32,8 @@ mavenJob('beam_PostCommit_Java_MavenInstall_Windows') { common_job_properties.setMavenConfig(delegate, 'Maven 3.3.3 (Windows)') // Sets that this is a PostCommit job. - common_job_properties.setPostCommit(delegate, '0 */6 * * *', false) + // TODO(BEAM-1042, BEAM-1045, BEAM-2269, BEAM-2299) Turn notifications back on once fixed. + common_job_properties.setPostCommit(delegate, '0 */6 * * *', false, '', false) // Allows triggering this build against pull requests. common_job_properties.enablePhraseTriggeringFromPullRequest( From c11f0ff57efca5786fb5da20006d9eb96b44cffe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Fri, 9 Jun 2017 00:01:55 +0200 Subject: [PATCH 078/200] Fix minor issues on HCatalogIO - Restrict access level when possible - Rename Filter to Partition for the write to be coherent with the HCatalog API - Improve test coverage - Fix documentation details - Implement TearDown method for the writer --- .../beam/sdk/io/hcatalog/HCatalogIO.java | 113 ++++++++---------- .../io/hcatalog/EmbeddedMetastoreService.java | 3 +- .../beam/sdk/io/hcatalog/HCatalogIOTest.java | 54 +++++---- .../sdk/io/hcatalog/HCatalogIOTestUtils.java | 22 ++-- 4 files changed, 90 insertions(+), 102 deletions(-) diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java index 07b56e3f650cd..1549dab048422 100644 --- a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java +++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java @@ -78,11 +78,10 @@ * * pipeline * .apply(HCatalogIO.read() - * .withConfigProperties(configProperties) //mandatory - * .withTable("employee") //mandatory + * .withConfigProperties(configProperties) * .withDatabase("default") //optional, assumes default if none specified - * .withFilter(filterString) //optional, - * should be specified if the table is partitioned + * .withTable("employee") + * .withFilter(filterString) //optional, may be specified if the table is partitioned * } * *

 * <h3>Writing using HCatalog</h3>

    @@ -100,13 +99,11 @@ * pipeline * .apply(...) * .apply(HiveIO.write() - * .withConfigProperties(configProperties) //mandatory - * .withTable("employee") //mandatory + * .withConfigProperties(configProperties) * .withDatabase("default") //optional, assumes default if none specified - * .withFilter(partitionValues) //optional, - * should be specified if the table is partitioned - * .withBatchSize(1024L)) //optional, - * assumes a default batch size of 1024 if none specified + * .withTable("employee") + * .withPartition(partitionValues) //optional, may be specified if the table is partitioned + * .withBatchSize(1024L)) //optional, assumes a default batch size of 1024 if none specified * } */ @Experimental @@ -114,14 +111,17 @@ public class HCatalogIO { private static final Logger LOG = LoggerFactory.getLogger(HCatalogIO.class); + private static final long BATCH_SIZE = 1024L; + private static final String DEFAULT_DATABASE = "default"; + /** Write data to Hive. */ public static Write write() { - return new AutoValue_HCatalogIO_Write.Builder().setBatchSize(1024L).build(); + return new AutoValue_HCatalogIO_Write.Builder().setBatchSize(BATCH_SIZE).build(); } /** Read data from Hive. */ public static Read read() { - return new AutoValue_HCatalogIO_Read.Builder().setDatabase("default").build(); + return new AutoValue_HCatalogIO_Read.Builder().setDatabase(DEFAULT_DATABASE).build(); } private HCatalogIO() {} @@ -130,44 +130,26 @@ private HCatalogIO() {} @VisibleForTesting @AutoValue public abstract static class Read extends PTransform> { - @Nullable - abstract Map getConfigProperties(); - - @Nullable - abstract String getDatabase(); - - @Nullable - abstract String getTable(); - - @Nullable - abstract String getFilter(); - - @Nullable - abstract ReaderContext getContext(); - - @Nullable - abstract Integer getSplitId(); - + @Nullable abstract Map getConfigProperties(); + @Nullable abstract String getDatabase(); + @Nullable abstract String getTable(); + @Nullable abstract String getFilter(); + @Nullable abstract ReaderContext getContext(); + @Nullable abstract Integer getSplitId(); abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { abstract Builder setConfigProperties(Map configProperties); - abstract Builder setDatabase(String database); - abstract Builder setTable(String table); - abstract Builder setFilter(String filter); - abstract Builder setSplitId(Integer splitId); - abstract Builder setContext(ReaderContext context); - abstract Read build(); } - /** Sets the configuration properties like metastore URI. This is mandatory */ + /** Sets the configuration properties like metastore URI. */ public Read withConfigProperties(Map configProperties) { return toBuilder().setConfigProperties(new HashMap<>(configProperties)).build(); } @@ -177,12 +159,12 @@ public Read withDatabase(String database) { return toBuilder().setDatabase(database).build(); } - /** Sets the table name to read from. This is mandatory */ + /** Sets the table name to read from. */ public Read withTable(String table) { return toBuilder().setTable(table).build(); } - /** Sets the filter (partition) details. This is optional, assumes none if not specified */ + /** Sets the filter details. This is optional, assumes none if not specified */ public Read withFilter(String filter) { return toBuilder().setFilter(filter).build(); } @@ -220,7 +202,7 @@ public void populateDisplayData(DisplayData.Builder builder) { /** A HCatalog {@link BoundedSource} reading {@link HCatRecord} from a given instance. 
*/ @VisibleForTesting static class BoundedHCatalogSource extends BoundedSource { - private Read spec; + private final Read spec; BoundedHCatalogSource(Read spec) { this.spec = spec; @@ -367,38 +349,24 @@ public void close() { /** A {@link PTransform} to write to a HCatalog managed source. */ @AutoValue public abstract static class Write extends PTransform, PDone> { - @Nullable - abstract Map getConfigProperties(); - - @Nullable - abstract String getDatabase(); - - @Nullable - abstract String getTable(); - - @Nullable - abstract Map getFilter(); - + @Nullable abstract Map getConfigProperties(); + @Nullable abstract String getDatabase(); + @Nullable abstract String getTable(); + @Nullable abstract Map getPartition(); abstract long getBatchSize(); - abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { abstract Builder setConfigProperties(Map configProperties); - abstract Builder setDatabase(String database); - abstract Builder setTable(String table); - - abstract Builder setFilter(Map partition); - + abstract Builder setPartition(Map partition); abstract Builder setBatchSize(long batchSize); - abstract Write build(); } - /** Sets the configuration properties like metastore URI. This is mandatory */ + /** Sets the configuration properties like metastore URI. */ public Write withConfigProperties(Map configProperties) { return toBuilder().setConfigProperties(new HashMap<>(configProperties)).build(); } @@ -408,14 +376,14 @@ public Write withDatabase(String database) { return toBuilder().setDatabase(database).build(); } - /** Sets the table name to write to, the table should exist beforehand. This is mandatory */ + /** Sets the table name to write to, the table should exist beforehand. */ public Write withTable(String table) { return toBuilder().setTable(table).build(); } - /** Sets the filter (partition) details. This is required if the table is partitioned */ - public Write withFilter(Map filter) { - return toBuilder().setFilter(filter).build(); + /** Sets the partition details. 
*/ + public Write withPartition(Map partition) { + return toBuilder().setPartition(partition).build(); } /** @@ -454,7 +422,7 @@ public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.addIfNotNull(DisplayData.item("database", spec.getDatabase())); builder.add(DisplayData.item("table", spec.getTable())); - builder.addIfNotNull(DisplayData.item("filter", String.valueOf(spec.getFilter()))); + builder.addIfNotNull(DisplayData.item("partition", String.valueOf(spec.getPartition()))); builder.add(DisplayData.item("configProperties", spec.getConfigProperties().toString())); builder.add(DisplayData.item("batchSize", spec.getBatchSize())); } @@ -465,7 +433,7 @@ public void initiateWrite() throws HCatException { new WriteEntity.Builder() .withDatabase(spec.getDatabase()) .withTable(spec.getTable()) - .withPartition(spec.getFilter()) + .withPartition(spec.getPartition()) .build(); masterWriter = DataTransferFactory.getHCatWriter(entity, spec.getConfigProperties()); writerContext = masterWriter.prepareWrite(); @@ -506,6 +474,19 @@ private void flush() throws HCatException { hCatRecordsBatch.clear(); } } + + @Teardown + public void tearDown() throws Exception { + if (slaveWriter != null) { + slaveWriter = null; + } + if (masterWriter != null) { + masterWriter = null; + } + if (writerContext != null) { + writerContext = null; + } + } } } } diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java index 5792bf6f810cd..31e5b1cf44d28 100644 --- a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/EmbeddedMetastoreService.java @@ -35,7 +35,7 @@ * https://github.com/apache/hive/blob/master/hcatalog/core/src/test/java/org/apache/hive/hcatalog/mapreduce * /HCatBaseTest.java */ -public final class EmbeddedMetastoreService implements AutoCloseable { +final class EmbeddedMetastoreService implements AutoCloseable { private final Driver driver; private final HiveConf hiveConf; private final SessionState sessionState; @@ -57,7 +57,6 @@ public final class EmbeddedMetastoreService implements AutoCloseable { hiveConf.setVar(HiveConf.ConfVars.POSTEXECHOOKS, ""); hiveConf.setBoolVar(HiveConf.ConfVars.HIVE_SUPPORT_CONCURRENCY, false); hiveConf.setVar(HiveConf.ConfVars.METASTOREWAREHOUSE, testWarehouseDirPath); - hiveConf.setVar(HiveConf.ConfVars.HIVEMAPREDMODE, "nonstrict"); hiveConf.setBoolVar(HiveConf.ConfVars.HIVEOPTIMIZEMETADATAQUERIES, true); hiveConf.setVar( HiveConf.ConfVars.HIVE_AUTHORIZATION_MANAGER, diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java index 49c538f0eb695..91671a522a598 100644 --- a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java @@ -17,8 +17,10 @@ */ package org.apache.beam.sdk.io.hcatalog; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_DATABASE; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_FILTER; import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_RECORDS_COUNT; -import static 
org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_TABLE_NAME; +import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.TEST_TABLE; import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getConfigPropertiesAsMap; import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getExpectedRecords; import static org.apache.beam.sdk.io.hcatalog.HCatalogIOTestUtils.getHCatRecords; @@ -69,7 +71,7 @@ /** Test for HCatalogIO. */ public class HCatalogIOTest implements Serializable { - public static final PipelineOptions OPTIONS = PipelineOptionsFactory.create(); + private static final PipelineOptions OPTIONS = PipelineOptionsFactory.create(); @ClassRule public static final TemporaryFolder TMP_FOLDER = new TemporaryFolder(); @@ -103,12 +105,12 @@ public void evaluate() throws Throwable { /** Use this annotation to setup complete test data(table populated with records). */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.METHOD}) - @interface NeedsTestData {} + private @interface NeedsTestData {} /** Use this annotation to setup test tables alone(empty tables, no records are populated). */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.METHOD}) - @interface NeedsEmptyTestTables {} + private @interface NeedsEmptyTestTables {} @BeforeClass public static void setupEmbeddedMetastoreService () throws IOException { @@ -117,7 +119,7 @@ public static void setupEmbeddedMetastoreService () throws IOException { @AfterClass public static void shutdownEmbeddedMetastoreService () throws Exception { - service.executeQuery("drop table " + TEST_TABLE_NAME); + service.executeQuery("drop table " + TEST_TABLE); service.close(); } @@ -130,23 +132,27 @@ public void testWriteThenReadSuccess() throws Exception { .apply( HCatalogIO.write() .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) - .withTable(TEST_TABLE_NAME)); + .withDatabase(TEST_DATABASE) + .withTable(TEST_TABLE) + .withPartition(new java.util.HashMap()) + .withBatchSize(512L)); defaultPipeline.run(); - PCollection output = - readAfterWritePipeline - .apply( - HCatalogIO.read() - .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) - .withTable(HCatalogIOTestUtils.TEST_TABLE_NAME)) - .apply( - ParDo.of( - new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element().get(0).toString()); - } - })); + PCollection output = readAfterWritePipeline + .apply( + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withDatabase(TEST_DATABASE) + .withTable(TEST_TABLE) + .withFilter(TEST_FILTER)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().get(0).toString()); + } + })); PAssert.that(output).containsInAnyOrder(getExpectedRecords(TEST_RECORDS_COUNT)); readAfterWritePipeline.run(); } @@ -222,7 +228,7 @@ public void testReadFromSource() throws Exception { HCatalogIO.read() .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) .withContext(context) - .withTable(TEST_TABLE_NAME); + .withTable(TEST_TABLE); List records = new ArrayList<>(); for (int i = 0; i < context.numSplits(); i++) { @@ -246,7 +252,7 @@ public void testSourceEqualsSplits() throws Exception { HCatalogIO.read() .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) .withContext(context) - .withTable(TEST_TABLE_NAME); + .withTable(TEST_TABLE); BoundedHCatalogSource source = new BoundedHCatalogSource(spec); List> unSplitSource = 
source.split(-1, OPTIONS); @@ -260,8 +266,8 @@ public void testSourceEqualsSplits() throws Exception { } private void reCreateTestTable() throws CommandNeedRetryException { - service.executeQuery("drop table " + TEST_TABLE_NAME); - service.executeQuery("create table " + TEST_TABLE_NAME + "(mycol1 string, mycol2 int)"); + service.executeQuery("drop table " + TEST_TABLE); + service.executeQuery("create table " + TEST_TABLE + "(mycol1 string, mycol2 int)"); } private void prepareTestData() throws Exception { diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java index f66e0bcc1e8fa..ae1eb50d608b7 100644 --- a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTestUtils.java @@ -35,15 +35,16 @@ import org.apache.hive.hcatalog.data.transfer.WriterContext; /** Utility class for HCatalogIOTest. */ -public class HCatalogIOTestUtils { - public static final String TEST_TABLE_NAME = "mytable"; - - public static final int TEST_RECORDS_COUNT = 1000; +class HCatalogIOTestUtils { + static final String TEST_DATABASE = "default"; + static final String TEST_TABLE = "mytable"; + static final String TEST_FILTER = "myfilter"; + static final int TEST_RECORDS_COUNT = 1000; private static final ReadEntity READ_ENTITY = - new ReadEntity.Builder().withTable(TEST_TABLE_NAME).build(); + new ReadEntity.Builder().withTable(TEST_TABLE).build(); private static final WriteEntity WRITE_ENTITY = - new WriteEntity.Builder().withTable(TEST_TABLE_NAME).build(); + new WriteEntity.Builder().withTable(TEST_TABLE).build(); /** Returns a ReaderContext instance for the passed datastore config params. */ static ReaderContext getReaderContext(Map config) throws HCatException { @@ -51,17 +52,18 @@ static ReaderContext getReaderContext(Map config) throws HCatExc } /** Returns a WriterContext instance for the passed datastore config params. */ - static WriterContext getWriterContext(Map config) throws HCatException { + private static WriterContext getWriterContext(Map config) throws HCatException { return DataTransferFactory.getHCatWriter(WRITE_ENTITY, config).prepareWrite(); } /** Writes records to the table using the passed WriterContext. */ - static void writeRecords(WriterContext context) throws HCatException { + private static void writeRecords(WriterContext context) throws HCatException { DataTransferFactory.getHCatWriter(context).write(getHCatRecords(TEST_RECORDS_COUNT).iterator()); } /** Commits the pending writes to the database. */ - static void commitRecords(Map config, WriterContext context) throws IOException { + private static void commitRecords(Map config, WriterContext context) + throws IOException { DataTransferFactory.getHCatWriter(WRITE_ENTITY, config).commit(context); } @@ -100,7 +102,7 @@ static Map getConfigPropertiesAsMap(HiveConf hiveConf) { } /** returns a DefaultHCatRecord instance for passed value. 
*/ - static DefaultHCatRecord toHCatRecord(int value) { + private static DefaultHCatRecord toHCatRecord(int value) { return new DefaultHCatRecord(Arrays.asList("record " + value, value)); } } From 3520f94882b00aa8db64f6379044689d1b78ac06 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Tue, 20 Jun 2017 17:16:20 -0700 Subject: [PATCH 079/200] Allow production of unprocessed bundles, introduce TestStream evaluator in DirectRunner --- .../runners/direct/evaluation_context.py | 14 +-- .../apache_beam/runners/direct/executor.py | 40 +++++++-- .../runners/direct/transform_evaluator.py | 88 +++++++++++++++++-- .../python/apache_beam/runners/direct/util.py | 4 +- .../runners/direct/watermark_manager.py | 11 ++- .../python/apache_beam/testing/test_stream.py | 5 ++ .../apache_beam/testing/test_stream_test.py | 37 ++++++++ 7 files changed, 176 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 976e9e8c8e958..669a68a13c7da 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -208,11 +208,12 @@ def handle_result( the committed bundles contained within the handled result. """ with self._lock: - committed_bundles = self._commit_bundles( - result.uncommitted_output_bundles) + committed_bundles, unprocessed_bundles = self._commit_bundles( + result.uncommitted_output_bundles, + result.unprocessed_bundles) self._watermark_manager.update_watermarks( completed_bundle, result.transform, completed_timers, - committed_bundles, result.watermark_hold) + committed_bundles, unprocessed_bundles, result.watermark_hold) self._metrics.commit_logical(completed_bundle, result.logical_metric_updates) @@ -252,14 +253,17 @@ def schedule_pending_unblocked_tasks(self, executor_service): executor_service.submit(task) self._pending_unblocked_tasks = [] - def _commit_bundles(self, uncommitted_bundles): + def _commit_bundles(self, uncommitted_bundles, unprocessed_bundles): """Commits bundles and returns a immutable set of committed bundles.""" for in_progress_bundle in uncommitted_bundles: producing_applied_ptransform = in_progress_bundle.pcollection.producer watermarks = self._watermark_manager.get_watermarks( producing_applied_ptransform) in_progress_bundle.commit(watermarks.synchronized_processing_output_time) - return tuple(uncommitted_bundles) + + for unprocessed_bundle in unprocessed_bundles: + unprocessed_bundle.commit(None) + return tuple(uncommitted_bundles), tuple(unprocessed_bundles) def get_execution_context(self, applied_ptransform): return _ExecutionContext( diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index a0a3886f733c8..e70e326978842 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -227,17 +227,25 @@ def __init__(self, evaluation_context, all_updates, timer_firings=None): self._all_updates = all_updates self._timer_firings = timer_firings or [] - def handle_result(self, input_committed_bundle, transform_result): + def handle_result(self, transform_executor, input_committed_bundle, + transform_result): output_committed_bundles = self._evaluation_context.handle_result( input_committed_bundle, self._timer_firings, transform_result) for output_committed_bundle in output_committed_bundles: self._all_updates.offer(_ExecutorServiceParallelExecutor._ExecutorUpdate( - 
output_committed_bundle, None)) + transform_executor, + committed_bundle=output_committed_bundle)) + for unprocessed_bundle in transform_result.unprocessed_bundles: + self._all_updates.offer( + _ExecutorServiceParallelExecutor._ExecutorUpdate( + transform_executor, + unprocessed_bundle=unprocessed_bundle)) return output_committed_bundles - def handle_exception(self, exception): + def handle_exception(self, transform_executor, exception): self._all_updates.offer( - _ExecutorServiceParallelExecutor._ExecutorUpdate(None, exception)) + _ExecutorServiceParallelExecutor._ExecutorUpdate( + transform_executor, exception=exception)) class TransformExecutor(_ExecutorService.CallableTask): @@ -312,10 +320,10 @@ def call(self): self._evaluation_context.append_to_cache( self._applied_ptransform, tag, value) - self._completion_callback.handle_result(self._input_bundle, result) + self._completion_callback.handle_result(self, self._input_bundle, result) return result except Exception as e: # pylint: disable=broad-except - self._completion_callback.handle_exception(e) + self._completion_callback.handle_exception(self, e) finally: self._evaluation_context.metrics().commit_physical( self._input_bundle, @@ -387,6 +395,10 @@ def schedule_consumers(self, committed_bundle): self.schedule_consumption(applied_ptransform, committed_bundle, [], self.default_completion_callback) + def schedule_unprocessed_bundle(self, applied_ptransform, + unprocessed_bundle): + self.node_to_pending_bundles[applied_ptransform].append(unprocessed_bundle) + def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, fired_timers, on_complete): """Schedules evaluation of the given bundle with the transform.""" @@ -433,10 +445,16 @@ def offer(self, item): class _ExecutorUpdate(object): """An internal status update on the state of the executor.""" - def __init__(self, produced_bundle=None, exception=None): + def __init__(self, transform_executor, committed_bundle=None, + unprocessed_bundle=None, exception=None): + self.transform_executor = transform_executor # Exactly one of them should be not-None - assert bool(produced_bundle) != bool(exception) - self.committed_bundle = produced_bundle + assert sum([ + bool(committed_bundle), + bool(unprocessed_bundle), + bool(exception)]) == 1 + self.committed_bundle = committed_bundle + self.unprocessed_bundle = unprocessed_bundle self.exception = exception self.exc_info = sys.exc_info() if self.exc_info[1] is not exception: @@ -471,6 +489,10 @@ def call(self): while update: if update.committed_bundle: self._executor.schedule_consumers(update.committed_bundle) + elif update.unprocessed_bundle: + self._executor.schedule_unprocessed_bundle( + update.transform_executor._applied_ptransform, + update.unprocessed_bundle) else: assert update.exception logging.warning('A task failed with exception.\n %s', diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index e92d799e3ed35..3aefbb8d5a1b2 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -31,6 +31,10 @@ from apache_beam.runners.direct.util import KeyedWorkItem from apache_beam.runners.direct.util import TransformResult from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access +from apache_beam.testing.test_stream import TestStream +from apache_beam.testing.test_stream import ElementEvent +from 
apache_beam.testing.test_stream import WatermarkEvent +from apache_beam.testing.test_stream import ProcessingTimeEvent from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue @@ -41,6 +45,7 @@ from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn from apache_beam.utils import counters +from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.options.pipeline_options import TypeOptions @@ -59,9 +64,11 @@ def __init__(self, evaluation_context): core.ParDo: _ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, _NativeWrite: _NativeWriteEvaluator, + TestStream: _TestStreamEvaluator, } self._root_bundle_providers = { core.PTransform: DefaultRootBundleProvider, + TestStream: _TestStreamRootBundleProvider, } def get_evaluator( @@ -142,6 +149,23 @@ def get_root_bundles(self): return [empty_bundle] +class _TestStreamRootBundleProvider(RootBundleProvider): + """Provides an initial bundle for the TestStream evaluator.""" + + def get_root_bundles(self): + test_stream = self._applied_ptransform.transform + bundles = [] + if len(test_stream.events) > 0: + bundle = self._evaluation_context.create_bundle( + pvalue.PBegin(self._applied_ptransform.transform.pipeline)) + # Explicitly set timestamp to MIN_TIMESTAMP to ensure that we hold the + # watermark. + bundle.add(GlobalWindows.windowed_value(0, timestamp=MIN_TIMESTAMP)) + bundle.commit(None) + bundles.append(bundle) + return bundles + + class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" @@ -265,7 +289,61 @@ def _read_values_to_bundles(reader): bundles = _read_values_to_bundles(reader) return TransformResult( - self._applied_ptransform, bundles, None, None) + self._applied_ptransform, bundles, [], None, None) + + +class _TestStreamEvaluator(_TransformEvaluator): + """TransformEvaluator for the TestStream transform.""" + + def __init__(self, evaluation_context, applied_ptransform, + input_committed_bundle, side_inputs, scoped_metrics_container): + assert not side_inputs + self.test_stream = applied_ptransform.transform + super(_TestStreamEvaluator, self).__init__( + evaluation_context, applied_ptransform, input_committed_bundle, + side_inputs, scoped_metrics_container) + + def start_bundle(self): + self.current_index = -1 + self.watermark = MIN_TIMESTAMP + self.bundles = [] + + def process_element(self, element): + index = element.value + self.watermark = element.timestamp + assert isinstance(index, int) + assert 0 <= index <= len(self.test_stream.events) + self.current_index = index + event = self.test_stream.events[self.current_index] + if isinstance(event, ElementEvent): + assert len(self._outputs) == 1 + output_pcollection = list(self._outputs)[0] + bundle = self._evaluation_context.create_bundle(output_pcollection) + for tv in event.timestamped_values: + bundle.output( + GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp)) + self.bundles.append(bundle) + elif isinstance(event, WatermarkEvent): + assert event.new_watermark >= self.watermark + self.watermark = event.new_watermark + elif isinstance(event, ProcessingTimeEvent): + # TODO(ccy): advance processing time in the context's mock clock. + pass + else: + raise ValueError('Invalid TestStream event: %s.' 
% event) + + def finish_bundle(self): + unprocessed_bundles = [] + hold = None + if self.current_index < len(self.test_stream.events) - 1: + unprocessed_bundle = self._evaluation_context.create_bundle( + pvalue.PBegin(self._applied_ptransform.transform.pipeline)) + unprocessed_bundle.add(GlobalWindows.windowed_value( + self.current_index + 1, timestamp=self.watermark)) + unprocessed_bundles.append(unprocessed_bundle) + hold = self.watermark + return TransformResult( + self._applied_ptransform, self.bundles, unprocessed_bundles, None, hold) class _FlattenEvaluator(_TransformEvaluator): @@ -289,7 +367,7 @@ def process_element(self, element): def finish_bundle(self): bundles = [self.bundle] return TransformResult( - self._applied_ptransform, bundles, None, None) + self._applied_ptransform, bundles, [], None, None) class _TaggedReceivers(dict): @@ -378,7 +456,7 @@ def finish_bundle(self): bundles = self._tagged_receivers.values() result_counters = self._counter_factory.get_counters() return TransformResult( - self._applied_ptransform, bundles, result_counters, None, + self._applied_ptransform, bundles, [], result_counters, None, self._tagged_receivers.undeclared_in_memory_tag_values) @@ -469,7 +547,7 @@ def len_element_fn(element): None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, bundles, None, hold) + self._applied_ptransform, bundles, [], None, hold) class _NativeWriteEvaluator(_TransformEvaluator): @@ -534,4 +612,4 @@ def finish_bundle(self): None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, [], None, hold) + self._applied_ptransform, [], [], None, hold) diff --git a/sdks/python/apache_beam/runners/direct/util.py b/sdks/python/apache_beam/runners/direct/util.py index daaaceb4738f7..8c846fc55eb4d 100644 --- a/sdks/python/apache_beam/runners/direct/util.py +++ b/sdks/python/apache_beam/runners/direct/util.py @@ -27,9 +27,11 @@ class TransformResult(object): """Result of evaluating an AppliedPTransform with a TransformEvaluator.""" def __init__(self, applied_ptransform, uncommitted_output_bundles, - counters, watermark_hold, undeclared_tag_values=None): + unprocessed_bundles, counters, watermark_hold, + undeclared_tag_values=None): self.transform = applied_ptransform self.uncommitted_output_bundles = uncommitted_output_bundles + self.unprocessed_bundles = unprocessed_bundles self.counters = counters self.watermark_hold = watermark_hold # Only used when caching (materializing) all values is requested. 
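For reference, the constructor change above means every DirectRunner evaluator now passes an explicit list of unprocessed bundles when it builds its result. A minimal sketch of the new calling convention follows; the helper name `build_result` and the variable names `outputs`, `leftover` and `hold` are illustrative placeholders, not identifiers from this patch:

    from apache_beam.runners.direct.util import TransformResult

    def build_result(applied_ptransform, outputs, leftover, hold):
      # `outputs` are the uncommitted output bundles produced by the evaluator.
      # `leftover` holds input that still needs processing (for example the
      # TestStream events that were not replayed yet); pass [] when the whole
      # input bundle was consumed.
      return TransformResult(
          applied_ptransform,  # transform that was evaluated
          outputs,             # uncommitted_output_bundles
          leftover,            # unprocessed_bundles (new argument in this patch)
          None,                # counters
          hold)                # watermark_hold

The _TestStreamEvaluator above follows this pattern: its finish_bundle() returns the not-yet-replayed events as an unprocessed bundle so the runner reschedules the transform.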
diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 10d25d7f07aa0..2146bb5d9b1c2 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -93,17 +93,19 @@ def get_watermarks(self, applied_ptransform): return self._transform_to_watermarks[applied_ptransform] def update_watermarks(self, completed_committed_bundle, applied_ptransform, - completed_timers, outputs, earliest_hold): + completed_timers, outputs, unprocessed_bundles, + earliest_hold): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( completed_committed_bundle, applied_ptransform, completed_timers, - outputs) + outputs, unprocessed_bundles) tw = self.get_watermarks(applied_ptransform) tw.hold(earliest_hold) self._refresh_watermarks(applied_ptransform) def _update_pending(self, input_committed_bundle, applied_ptransform, - completed_timers, output_committed_bundles): + completed_timers, output_committed_bundles, + unprocessed_bundles): """Updated list of pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. They do not impact @@ -119,6 +121,9 @@ def _update_pending(self, input_committed_bundle, applied_ptransform, completed_tw = self._transform_to_watermarks[applied_ptransform] completed_tw.update_timers(completed_timers) + for unprocessed_bundle in unprocessed_bundles: + completed_tw.add_pending(unprocessed_bundle) + assert input_committed_bundle or applied_ptransform in self._root_transforms if input_committed_bundle and input_committed_bundle.has_elements(): completed_tw.remove_pending(input_committed_bundle) diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py index a06bcd0795f2c..7989fb2eee94d 100644 --- a/sdks/python/apache_beam/testing/test_stream.py +++ b/sdks/python/apache_beam/testing/test_stream.py @@ -24,8 +24,10 @@ from abc import abstractmethod from apache_beam import coders +from apache_beam import core from apache_beam import pvalue from apache_beam.transforms import PTransform +from apache_beam.transforms import window from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -99,6 +101,9 @@ def __init__(self, coder=coders.FastPrimitivesCoder): self.current_watermark = timestamp.MIN_TIMESTAMP self.events = [] + def get_windowing(self, unused_inputs): + return core.Windowing(window.GlobalWindows()) + def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin) self.pipeline = pbegin.pipeline diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index e32dda2aeb267..bf05ac16f7ef8 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -19,6 +19,8 @@ import unittest +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import ElementEvent from apache_beam.testing.test_stream import ProcessingTimeEvent from apache_beam.testing.test_stream import TestStream @@ -78,6 +80,41 @@ def test_test_stream_errors(self): TimestampedValue('a', timestamp.MAX_TIMESTAMP) ])) + def test_basic_execution(self): + test_stream = (TestStream() + .advance_watermark_to(10) + .add_elements(['a', 'b', 'c']) + 
.advance_watermark_to(20) + .add_elements(['d']) + .add_elements(['e']) + .advance_processing_time(10) + .advance_watermark_to(300) + .add_elements([TimestampedValue('late', 12)]) + .add_elements([TimestampedValue('last', 310)])) + + global _seen_elements # pylint: disable=global-variable-undefined + _seen_elements = [] + + class RecordFn(beam.DoFn): + def process(self, element=beam.DoFn.ElementParam, + timestamp=beam.DoFn.TimestampParam): + _seen_elements.append((element, timestamp)) + + p = TestPipeline() + my_record_fn = RecordFn() + p | test_stream | beam.ParDo(my_record_fn) # pylint: disable=expression-not-assigned + p.run() + + self.assertEqual([ + ('a', timestamp.Timestamp(10)), + ('b', timestamp.Timestamp(10)), + ('c', timestamp.Timestamp(10)), + ('d', timestamp.Timestamp(20)), + ('e', timestamp.Timestamp(20)), + ('late', timestamp.Timestamp(12)), + ('last', timestamp.Timestamp(310)),], _seen_elements) + del _seen_elements + if __name__ == '__main__': unittest.main() From c6d0d7983b19ce2e01b7b06a12f704fef17a00cc Mon Sep 17 00:00:00 2001 From: "chamikara@google.com" Date: Wed, 21 Jun 2017 10:37:11 -0700 Subject: [PATCH 080/200] Remove GroupedShuffleRangeTracker which is unused in the SDK --- sdks/python/apache_beam/io/range_trackers.py | 130 ------------ .../apache_beam/io/range_trackers_test.py | 186 ------------------ 2 files changed, 316 deletions(-) diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py index 9cb36e73dc8e7..bef77d4004768 100644 --- a/sdks/python/apache_beam/io/range_trackers.py +++ b/sdks/python/apache_beam/io/range_trackers.py @@ -193,136 +193,6 @@ def set_split_points_unclaimed_callback(self, callback): self._split_points_unclaimed_callback = callback -class GroupedShuffleRangeTracker(iobase.RangeTracker): - """For internal use only; no backwards-compatibility guarantees. - - A 'RangeTracker' for positions used by'GroupedShuffleReader'. - - These positions roughly correspond to hashes of keys. In case of hash - collisions, multiple groups can have the same position. In that case, the - first group at a particular position is considered a split point (because - it is the first to be returned when reading a position range starting at this - position), others are not. 
- """ - - def __init__(self, decoded_start_pos, decoded_stop_pos): - super(GroupedShuffleRangeTracker, self).__init__() - self._decoded_start_pos = decoded_start_pos - self._decoded_stop_pos = decoded_stop_pos - self._decoded_last_group_start = None - self._last_group_was_at_a_split_point = False - self._split_points_seen = 0 - self._lock = threading.Lock() - - def start_position(self): - return self._decoded_start_pos - - def stop_position(self): - return self._decoded_stop_pos - - def last_group_start(self): - return self._decoded_last_group_start - - def _validate_decoded_group_start(self, decoded_group_start, split_point): - if self.start_position() and decoded_group_start < self.start_position(): - raise ValueError('Trying to return record at %r which is before the' - ' starting position at %r' % - (decoded_group_start, self.start_position())) - - if (self.last_group_start() and - decoded_group_start < self.last_group_start()): - raise ValueError('Trying to return group at %r which is before the' - ' last-returned group at %r' % - (decoded_group_start, self.last_group_start())) - if (split_point and self.last_group_start() and - self.last_group_start() == decoded_group_start): - raise ValueError('Trying to return a group at a split point with ' - 'same position as the previous group: both at %r, ' - 'last group was %sat a split point.' % - (decoded_group_start, - ('' if self._last_group_was_at_a_split_point - else 'not '))) - if not split_point: - if self.last_group_start() is None: - raise ValueError('The first group [at %r] must be at a split point' % - decoded_group_start) - if self.last_group_start() != decoded_group_start: - # This case is not a violation of general RangeTracker semantics, but it - # is contrary to how GroupingShuffleReader in particular works. Hitting - # it would mean it's behaving unexpectedly. - raise ValueError('Trying to return a group not at a split point, but ' - 'with a different position than the previous group: ' - 'last group was %r at %r, current at a %s split' - ' point.' 
% - (self.last_group_start() - , decoded_group_start - , ('' if self._last_group_was_at_a_split_point - else 'non-'))) - - def try_claim(self, decoded_group_start): - with self._lock: - self._validate_decoded_group_start(decoded_group_start, True) - if (self.stop_position() - and decoded_group_start >= self.stop_position()): - return False - - self._decoded_last_group_start = decoded_group_start - self._last_group_was_at_a_split_point = True - self._split_points_seen += 1 - return True - - def set_current_position(self, decoded_group_start): - with self._lock: - self._validate_decoded_group_start(decoded_group_start, False) - self._decoded_last_group_start = decoded_group_start - self._last_group_was_at_a_split_point = False - - def try_split(self, decoded_split_position): - with self._lock: - if self.last_group_start() is None: - logging.info('Refusing to split %r at %r: unstarted' - , self, decoded_split_position) - return - - if decoded_split_position <= self.last_group_start(): - logging.info('Refusing to split %r at %r: already past proposed split ' - 'position' - , self, decoded_split_position) - return - - if ((self.stop_position() - and decoded_split_position >= self.stop_position()) - or (self.start_position() - and decoded_split_position <= self.start_position())): - logging.error('Refusing to split %r at %r: proposed split position out ' - 'of range', self, decoded_split_position) - return - - logging.debug('Agreeing to split %r at %r' - , self, decoded_split_position) - self._decoded_stop_pos = decoded_split_position - - # Since GroupedShuffleRangeTracker cannot determine relative sizes of the - # two splits, returning 0.5 as the fraction below so that the framework - # assumes the splits to be of the same size. - return self._decoded_stop_pos, 0.5 - - def fraction_consumed(self): - # GroupingShuffle sources have special support on the service and the - # service will estimate progress from positions for us. - raise RuntimeError('GroupedShuffleRangeTracker does not measure fraction' - ' consumed due to positions being opaque strings' - ' that are interpreted by the service') - - def split_points(self): - with self._lock: - splits_points_consumed = ( - 0 if self._split_points_seen <= 1 else (self._split_points_seen - 1)) - - return (splits_points_consumed, - iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) - - class OrderedPositionRangeTracker(iobase.RangeTracker): """ An abstract base class for range trackers whose positions are comparable. 
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py index edb6386379b29..3e926634c85f9 100644 --- a/sdks/python/apache_beam/io/range_trackers_test.py +++ b/sdks/python/apache_beam/io/range_trackers_test.py @@ -17,14 +17,11 @@ """Unit tests for the range_trackers module.""" -import array import copy import logging import math import unittest - -from apache_beam.io import iobase from apache_beam.io import range_trackers @@ -189,189 +186,6 @@ def dummy_callback(stop_position): (3, 41)) -class GroupedShuffleRangeTrackerTest(unittest.TestCase): - - def bytes_to_position(self, bytes_array): - return array.array('B', bytes_array).tostring() - - def test_try_return_record_in_infinite_range(self): - tracker = range_trackers.GroupedShuffleRangeTracker('', '') - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 3]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 5]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 6, 8, 10]))) - - def test_try_return_record_finite_range(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0])) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 3]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 5]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 6, 8, 10]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([4, 255, 255, 255]))) - # Should fail for positions that are lexicographically equal to or larger - # than the defined stop position. - self.assertFalse(copy.copy(tracker).try_claim( - self.bytes_to_position([5, 0, 0]))) - self.assertFalse(copy.copy(tracker).try_claim( - self.bytes_to_position([5, 0, 1]))) - self.assertFalse(copy.copy(tracker).try_claim( - self.bytes_to_position([6, 0, 0]))) - - def test_try_return_record_with_non_split_point(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0])) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 3]))) - tracker.set_current_position(self.bytes_to_position([1, 2, 3])) - tracker.set_current_position(self.bytes_to_position([1, 2, 3])) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 5]))) - tracker.set_current_position(self.bytes_to_position([1, 2, 5])) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 6, 8, 10]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([4, 255, 255, 255]))) - - def test_first_record_non_split_point(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) - with self.assertRaises(ValueError): - tracker.set_current_position(self.bytes_to_position([3, 4, 5])) - - def test_non_split_point_record_with_different_position(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) - self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5]))) - with self.assertRaises(ValueError): - tracker.set_current_position(self.bytes_to_position([3, 4, 6])) - - def test_try_return_record_before_start(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) - with self.assertRaises(ValueError): - tracker.try_claim(self.bytes_to_position([1, 2, 3])) - - def 
test_try_return_non_monotonic(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) - self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5]))) - self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 6]))) - with self.assertRaises(ValueError): - tracker.try_claim(self.bytes_to_position([3, 2, 1])) - - def test_try_return_identical_positions(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 4, 5]))) - with self.assertRaises(ValueError): - tracker.try_claim(self.bytes_to_position([3, 4, 5])) - - def test_try_split_at_position_infinite_range(self): - tracker = range_trackers.GroupedShuffleRangeTracker('', '') - # Should fail before first record is returned. - self.assertFalse(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 3]))) - - # Should now succeed. - self.assertIsNotNone(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - # Should not split at same or larger position. - self.assertIsNone(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - self.assertIsNone(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6, 7]))) - self.assertIsNone(tracker.try_split( - self.bytes_to_position([4, 5, 6, 7]))) - - # Should split at smaller position. - self.assertIsNotNone(tracker.try_split( - self.bytes_to_position([3, 2, 1]))) - - self.assertTrue(tracker.try_claim( - self.bytes_to_position([2, 3, 4]))) - - # Should not split at a position we're already past. - self.assertIsNone(tracker.try_split( - self.bytes_to_position([2, 3, 4]))) - self.assertIsNone(tracker.try_split( - self.bytes_to_position([2, 3, 3]))) - - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 2, 0]))) - self.assertFalse(tracker.try_claim( - self.bytes_to_position([3, 2, 1]))) - - def test_try_test_split_at_position_finite_range(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([0, 0, 0]), - self.bytes_to_position([10, 20, 30])) - # Should fail before first record is returned. - self.assertFalse(tracker.try_split( - self.bytes_to_position([0, 0, 0]))) - self.assertFalse(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - - self.assertTrue(tracker.try_claim( - self.bytes_to_position([1, 2, 3]))) - - # Should now succeed. - self.assertTrue(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - # Should not split at same or larger position. - self.assertFalse(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6]))) - self.assertFalse(tracker.try_split( - self.bytes_to_position([3, 4, 5, 6, 7]))) - self.assertFalse(tracker.try_split( - self.bytes_to_position([4, 5, 6, 7]))) - - # Should split at smaller position. - self.assertTrue(tracker.try_split( - self.bytes_to_position([3, 2, 1]))) - # But not at a position at or before last returned record. 
- self.assertFalse(tracker.try_split( - self.bytes_to_position([1, 2, 3]))) - - self.assertTrue(tracker.try_claim( - self.bytes_to_position([2, 3, 4]))) - self.assertTrue(tracker.try_claim( - self.bytes_to_position([3, 2, 0]))) - self.assertFalse(tracker.try_claim( - self.bytes_to_position([3, 2, 1]))) - - def test_split_points(self): - tracker = range_trackers.GroupedShuffleRangeTracker( - self.bytes_to_position([1, 0, 0]), - self.bytes_to_position([5, 0, 0])) - self.assertEqual(tracker.split_points(), - (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 3]))) - self.assertEqual(tracker.split_points(), - (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 5]))) - self.assertEqual(tracker.split_points(), - (1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 6, 8]))) - self.assertEqual(tracker.split_points(), - (2, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - self.assertTrue(tracker.try_claim(self.bytes_to_position([4, 255, 255]))) - self.assertEqual(tracker.split_points(), - (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - self.assertFalse(tracker.try_claim(self.bytes_to_position([5, 1, 0]))) - self.assertEqual(tracker.split_points(), - (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) - - class OrderedPositionRangeTrackerTest(unittest.TestCase): class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker): From cd886300719ac9d702fbe7b105b09bdc5bbe0d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Author=3A=20=E6=B3=A2=E7=89=B9?= Date: Fri, 26 May 2017 17:40:27 +0800 Subject: [PATCH 081/200] ReduceFnRunner.onTrigger: skip storeCurrentPaneInfo() if trigger isFinished. --- .../java/org/apache/beam/runners/core/ReduceFnRunner.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java index 62d519f6f8e5d..b5c3e3ecc016a 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java @@ -948,7 +948,7 @@ private void prefetchOnTrigger( private Instant onTrigger( final ReduceFn.Context directContext, ReduceFn.Context renamedContext, - boolean isFinished, boolean isEndOfWindow) + final boolean isFinished, boolean isEndOfWindow) throws Exception { Instant inputWM = timerInternals.currentInputWatermarkTime(); @@ -1005,9 +1005,11 @@ private Instant onTrigger( @Override public void output(OutputT toOutput) { // We're going to output panes, so commit the (now used) PaneInfo. - // TODO: This is unnecessary if the trigger isFinished since the saved + // This is unnecessary if the trigger isFinished since the saved // state will be immediately deleted. - paneInfoTracker.storeCurrentPaneInfo(directContext, pane); + if (!isFinished) { + paneInfoTracker.storeCurrentPaneInfo(directContext, pane); + } // Output the actual value. 
outputter.outputWindowedValue( From 17c50122e684655846c4e07f19d16a38fa47d5a3 Mon Sep 17 00:00:00 2001 From: Mark Liu Date: Wed, 21 Jun 2017 14:28:26 -0700 Subject: [PATCH 082/200] [BEAM-2495] Add Python test dependency six>=1.9 --- sdks/python/setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 584c852c57b82..6646a58e529ab 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -112,6 +112,8 @@ def get_version(): REQUIRED_TEST_PACKAGES = [ 'pyhamcrest>=1.9,<2.0', + # Six required by nose plugins management. + 'six>=1.9', ] GCP_REQUIREMENTS = [ From aa65ea11e6e0d50864de21340219b5f4d019dbc2 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Wed, 21 Jun 2017 10:32:14 -0700 Subject: [PATCH 083/200] Add example for Bigquery streaming sink --- .../examples/windowed_wordcount.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 sdks/python/apache_beam/examples/windowed_wordcount.py diff --git a/sdks/python/apache_beam/examples/windowed_wordcount.py b/sdks/python/apache_beam/examples/windowed_wordcount.py new file mode 100644 index 0000000000000..bd57847c67f98 --- /dev/null +++ b/sdks/python/apache_beam/examples/windowed_wordcount.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A streaming word-counting workflow. + +Important: streaming pipeline support in Python Dataflow is in development +and is not yet available for use. 
+""" + +from __future__ import absolute_import + +import argparse +import logging + + +import apache_beam as beam +import apache_beam.transforms.window as window + +TABLE_SCHEMA = ('word:STRING, count:INTEGER, ' + 'window_start:TIMESTAMP, window_end:TIMESTAMP') + + +def find_words(element): + import re + return re.findall(r'[A-Za-z\']+', element) + + +class FormatDoFn(beam.DoFn): + def process(self, element, window=beam.DoFn.WindowParam): + ts_format = '%Y-%m-%d %H:%M:%S.%f UTC' + window_start = window.start.to_utc_datetime().strftime(ts_format) + window_end = window.end.to_utc_datetime().strftime(ts_format) + return [{'word': element[0], + 'count': element[1], + 'window_start':window_start, + 'window_end':window_end}] + + +def run(argv=None): + """Build and run the pipeline.""" + + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_topic', required=True, + help='Input PubSub topic of the form "/topics//".') + parser.add_argument( + '--output_table', required=True, + help= + ('Output BigQuery table for results specified as: PROJECT:DATASET.TABLE ' + 'or DATASET.TABLE.')) + known_args, pipeline_args = parser.parse_known_args(argv) + + with beam.Pipeline(argv=pipeline_args) as p: + + # Read the text from PubSub messages + lines = p | beam.io.ReadStringsFromPubSub(known_args.input_topic) + + # Capitalize the characters in each line. + transformed = (lines + | 'Split' >> (beam.FlatMap(find_words) + .with_output_types(unicode)) + | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) + | beam.WindowInto(window.FixedWindows(2*60, 0)) + | 'Group' >> beam.GroupByKey() + | 'Count' >> beam.Map(lambda (word, ones): (word, sum(ones))) + | 'Format' >> beam.ParDo(FormatDoFn())) + + # Write to BigQuery. + # pylint: disable=expression-not-assigned + transformed | 'Write' >> beam.io.WriteToBigQuery( + known_args.output_table, + schema=TABLE_SCHEMA, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() From 20820fa5477ffcdd4a9ef2e9340353ed3c5691a9 Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Mon, 12 Jun 2017 17:04:00 +0300 Subject: [PATCH 084/200] [BEAM-2359] Fix watermark broadcasting to executors in Spark runner --- .../beam/runners/spark/SparkRunner.java | 2 +- .../beam/runners/spark/TestSparkRunner.java | 2 +- .../SparkGroupAlsoByWindowViaWindowSet.java | 6 +- .../spark/stateful/SparkTimerInternals.java | 18 ++- .../spark/util/GlobalWatermarkHolder.java | 127 +++++++++++++----- .../spark/GlobalWatermarkHolderTest.java | 18 +-- 6 files changed, 120 insertions(+), 53 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index d008718af0cc8..595521fd9ff59 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -171,7 +171,7 @@ public SparkPipelineResult run(final Pipeline pipeline) { } // register Watermarks listener to broadcast the advanced WMs. 
- jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener(jssc))); + jssc.addStreamingListener(new JavaStreamingListenerWrapper(new WatermarksListener())); // The reason we call initAccumulators here even though it is called in // SparkRunnerStreamingContextFactory is because the factory is not called when resuming diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index eccee574afaa0..a13a3b141aa49 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -169,7 +169,7 @@ private static void awaitWatermarksOrTimeout( result.waitUntilFinish(Duration.millis(batchDurationMillis)); do { SparkTimerInternals sparkTimerInternals = - SparkTimerInternals.global(GlobalWatermarkHolder.get()); + SparkTimerInternals.global(GlobalWatermarkHolder.get(batchDurationMillis)); sparkTimerInternals.advanceWatermark(); globalWatermark = sparkTimerInternals.currentInputWatermarkTime(); // let another batch-interval period of execution, just to reason about WM propagation. diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index be4f3f65a3b7a..1385e071978f7 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -104,13 +104,15 @@ private abstract static class SerializableFunction1 public static JavaDStream>>> groupAlsoByWindow( - JavaDStream>>>> inputDStream, + final JavaDStream>>>> inputDStream, final Coder keyCoder, final Coder> wvCoder, final WindowingStrategy windowingStrategy, final SparkRuntimeContext runtimeContext, final List sourceIds) { + final long batchDurationMillis = + runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis(); final IterableCoder> itrWvCoder = IterableCoder.of(wvCoder); final Coder iCoder = ((FullWindowedValueCoder) wvCoder).getValueCoder(); final Coder wCoder = @@ -239,7 +241,7 @@ public JavaPairRDD call( SparkStateInternals stateInternals; SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources( - sourceIds, GlobalWatermarkHolder.get()); + sourceIds, GlobalWatermarkHolder.get(batchDurationMillis)); // get state(internals) per key. if (prevStateAndTimersOpt.isEmpty()) { // no previous state. 
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index 107915f7f43d0..a68da5516da79 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.spark.broadcast.Broadcast; import org.joda.time.Instant; @@ -58,10 +57,10 @@ private SparkTimerInternals( /** Build the {@link TimerInternals} according to the feeding streams. */ public static SparkTimerInternals forStreamFromSources( List sourceIds, - @Nullable Broadcast> broadcast) { - // if broadcast is invalid for the specific ids, use defaults. - if (broadcast == null || broadcast.getValue().isEmpty() - || Collections.disjoint(sourceIds, broadcast.getValue().keySet())) { + Map watermarks) { + // if watermarks are invalid for the specific ids, use defaults. + if (watermarks == null || watermarks.isEmpty() + || Collections.disjoint(sourceIds, watermarks.keySet())) { return new SparkTimerInternals( BoundedWindow.TIMESTAMP_MIN_VALUE, BoundedWindow.TIMESTAMP_MIN_VALUE, new Instant(0)); } @@ -71,7 +70,7 @@ public static SparkTimerInternals forStreamFromSources( // synchronized processing time should clearly be synchronized. Instant synchronizedProcessingTime = null; for (Integer sourceId: sourceIds) { - SparkWatermarks sparkWatermarks = broadcast.getValue().get(sourceId); + SparkWatermarks sparkWatermarks = watermarks.get(sourceId); if (sparkWatermarks != null) { // keep slowest WMs. slowestLowWatermark = slowestLowWatermark.isBefore(sparkWatermarks.getLowWatermark()) @@ -94,10 +93,9 @@ public static SparkTimerInternals forStreamFromSources( } /** Build a global {@link TimerInternals} for all feeding streams.*/ - public static SparkTimerInternals global( - @Nullable Broadcast> broadcast) { - return broadcast == null ? forStreamFromSources(Collections.emptyList(), null) - : forStreamFromSources(Lists.newArrayList(broadcast.getValue().keySet()), broadcast); + public static SparkTimerInternals global(Map watermarks) { + return watermarks == null ? 
forStreamFromSources(Collections.emptyList(), null) + : forStreamFromSources(Lists.newArrayList(watermarks.keySet()), watermarks); } Collection getTimers() { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java index 8b384d8e2ab03..2cb6f26f8a0f4 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/GlobalWatermarkHolder.java @@ -21,31 +21,43 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.Maps; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nonnull; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkEnv; import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockResult; +import org.apache.spark.storage.BlockStore; +import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.JavaStreamingListener; import org.apache.spark.streaming.api.java.JavaStreamingListenerBatchCompleted; import org.joda.time.Instant; - +import scala.Option; /** - * A {@link Broadcast} variable to hold the global watermarks for a micro-batch. + * A {@link BlockStore} variable to hold the global watermarks for a micro-batch. * *
<p>
    For each source, holds a queue for the watermarks of each micro-batch that was read, * and advances the watermarks according to the queue (first-in-first-out). */ public class GlobalWatermarkHolder { - // the broadcast is broadcasted to the workers. - private static volatile Broadcast> broadcast = null; - // this should only live in the driver so transient. - private static final transient Map> sourceTimes = new HashMap<>(); + private static final Map> sourceTimes = new HashMap<>(); + private static final BlockId WATERMARKS_BLOCK_ID = BlockId.apply("broadcast_0WATERMARKS"); + + private static volatile Map driverWatermarks = null; + private static volatile LoadingCache> watermarkCache = null; public static void add(int sourceId, SparkWatermarks sparkWatermarks) { Queue timesQueue = sourceTimes.get(sourceId); @@ -71,22 +83,48 @@ public static void addAll(Map> sourceTimes) { * Returns the {@link Broadcast} containing the {@link SparkWatermarks} mapped * to their sources. */ - public static Broadcast> get() { - return broadcast; + @SuppressWarnings("unchecked") + public static Map get(Long cacheInterval) { + if (driverWatermarks != null) { + // if we are executing in local mode simply return the local values. + return driverWatermarks; + } else { + if (watermarkCache == null) { + initWatermarkCache(cacheInterval); + } + try { + return watermarkCache.get("SINGLETON"); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + } + + private static synchronized void initWatermarkCache(Long batchDuration) { + if (watermarkCache == null) { + watermarkCache = + CacheBuilder.newBuilder() + // expire watermarks every half batch duration to ensure they update in every batch. + .expireAfterWrite(batchDuration / 2, TimeUnit.MILLISECONDS) + .build(new WatermarksLoader()); + } } /** * Advances the watermarks to the next-in-line watermarks. * SparkWatermarks are monotonically increasing. */ - public static void advance(JavaSparkContext jsc) { - synchronized (GlobalWatermarkHolder.class){ + @SuppressWarnings("unchecked") + public static void advance() { + synchronized (GlobalWatermarkHolder.class) { + BlockManager blockManager = SparkEnv.get().blockManager(); + if (sourceTimes.isEmpty()) { return; } // update all sources' watermarks into the new broadcast. 
- Map newBroadcast = new HashMap<>(); + Map newValues = new HashMap<>(); for (Map.Entry> en: sourceTimes.entrySet()) { if (en.getValue().isEmpty()) { @@ -99,8 +137,22 @@ public static void advance(JavaSparkContext jsc) { Instant currentLowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; Instant currentHighWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; Instant currentSynchronizedProcessingTime = BoundedWindow.TIMESTAMP_MIN_VALUE; - if (broadcast != null && broadcast.getValue().containsKey(sourceId)) { - SparkWatermarks currentTimes = broadcast.getValue().get(sourceId); + + Option currentOption = blockManager.getRemote(WATERMARKS_BLOCK_ID); + Map current; + if (currentOption.isDefined()) { + current = (Map) currentOption.get().data().next(); + } else { + current = Maps.newHashMap(); + blockManager.putSingle( + WATERMARKS_BLOCK_ID, + current, + StorageLevel.MEMORY_ONLY(), + true); + } + + if (current.containsKey(sourceId)) { + SparkWatermarks currentTimes = current.get(sourceId); currentLowWatermark = currentTimes.getLowWatermark(); currentHighWatermark = currentTimes.getHighWatermark(); currentSynchronizedProcessingTime = currentTimes.getSynchronizedProcessingTime(); @@ -119,20 +171,21 @@ public static void advance(JavaSparkContext jsc) { nextLowWatermark, nextHighWatermark)); checkState(nextSynchronizedProcessingTime.isAfter(currentSynchronizedProcessingTime), "Synchronized processing time must advance."); - newBroadcast.put( + newValues.put( sourceId, new SparkWatermarks( nextLowWatermark, nextHighWatermark, nextSynchronizedProcessingTime)); } // update the watermarks broadcast only if something has changed. - if (!newBroadcast.isEmpty()) { - if (broadcast != null) { - // for now this is blocking, we could make this asynchronous - // but it could slow down WM propagation. - broadcast.destroy(); - } - broadcast = jsc.broadcast(newBroadcast); + if (!newValues.isEmpty()) { + driverWatermarks = newValues; + blockManager.removeBlock(WATERMARKS_BLOCK_ID, true); + blockManager.putSingle( + WATERMARKS_BLOCK_ID, + newValues, + StorageLevel.MEMORY_ONLY(), + true); } } } @@ -140,7 +193,12 @@ public static void advance(JavaSparkContext jsc) { @VisibleForTesting public static synchronized void clear() { sourceTimes.clear(); - broadcast = null; + driverWatermarks = null; + SparkEnv sparkEnv = SparkEnv.get(); + if (sparkEnv != null) { + BlockManager blockManager = sparkEnv.blockManager(); + blockManager.removeBlock(WATERMARKS_BLOCK_ID, true); + } } /** @@ -185,15 +243,24 @@ public String toString() { /** Advance the WMs onBatchCompleted event. 
*/ public static class WatermarksListener extends JavaStreamingListener { - private final JavaStreamingContext jssc; - - public WatermarksListener(JavaStreamingContext jssc) { - this.jssc = jssc; + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + GlobalWatermarkHolder.advance(); } + } + + private static class WatermarksLoader extends CacheLoader> { + @SuppressWarnings("unchecked") @Override - public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { - GlobalWatermarkHolder.advance(jssc.sparkContext()); + public Map load(@Nonnull String key) throws Exception { + Option blockResultOption = + SparkEnv.get().blockManager().getRemote(WATERMARKS_BLOCK_ID); + if (blockResultOption.isDefined()) { + return (Map) blockResultOption.get().data().next(); + } else { + return Maps.newHashMap(); + } } } } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java index 47a6e3fe74999..17081236cf527 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java @@ -65,17 +65,17 @@ public void testLowHighWatermarksAdvance() { instant.plus(Duration.millis(5)), instant.plus(Duration.millis(5)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // low < high. GlobalWatermarkHolder.add(1, new SparkWatermarks( instant.plus(Duration.millis(10)), instant.plus(Duration.millis(15)), instant.plus(Duration.millis(100)))); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // assert watermarks in Broadcast. - SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get().getValue().get(1); + SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get(0L).get(1); assertThat(currentWatermarks.getLowWatermark(), equalTo(instant.plus(Duration.millis(10)))); assertThat(currentWatermarks.getHighWatermark(), equalTo(instant.plus(Duration.millis(15)))); assertThat(currentWatermarks.getSynchronizedProcessingTime(), @@ -93,7 +93,7 @@ public void testLowHighWatermarksAdvance() { instant.plus(Duration.millis(25)), instant.plus(Duration.millis(20)), instant.plus(Duration.millis(200)))); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); } @Test @@ -106,7 +106,7 @@ public void testSynchronizedTimeMonotonic() { instant.plus(Duration.millis(5)), instant.plus(Duration.millis(10)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); thrown.expect(IllegalStateException.class); thrown.expectMessage("Synchronized processing time must advance."); @@ -117,7 +117,7 @@ public void testSynchronizedTimeMonotonic() { instant.plus(Duration.millis(5)), instant.plus(Duration.millis(10)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); } @Test @@ -136,15 +136,15 @@ public void testMultiSource() { instant.plus(Duration.millis(6)), instant)); - GlobalWatermarkHolder.advance(jsc); + GlobalWatermarkHolder.advance(); // assert watermarks for source 1. 
- SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get().getValue().get(1); + SparkWatermarks watermarksForSource1 = GlobalWatermarkHolder.get(0L).get(1); assertThat(watermarksForSource1.getLowWatermark(), equalTo(instant.plus(Duration.millis(5)))); assertThat(watermarksForSource1.getHighWatermark(), equalTo(instant.plus(Duration.millis(10)))); // assert watermarks for source 2. - SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get().getValue().get(2); + SparkWatermarks watermarksForSource2 = GlobalWatermarkHolder.get(0L).get(2); assertThat(watermarksForSource2.getLowWatermark(), equalTo(instant.plus(Duration.millis(3)))); assertThat(watermarksForSource2.getHighWatermark(), equalTo(instant.plus(Duration.millis(6)))); } From 22dbb500289675fe95b6d149c8550e09dc26feac Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Wed, 21 Jun 2017 17:53:21 +0300 Subject: [PATCH 085/200] Move Spark runner streaming tests to post commit. --- runners/spark/pom.xml | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml index 0f6b73091688c..ee72dd96fdf5c 100644 --- a/runners/spark/pom.xml +++ b/runners/spark/pom.xml @@ -103,6 +103,27 @@ 4 + + streaming-tests + test + + test + + + + org.apache.beam.runners.spark.StreamingTest + + + + [ + "--runner=TestSparkRunner", + "--forceStreaming=true", + "--enableSparkMetricSinks=true" + ] + + + + @@ -372,27 +393,6 @@ - - streaming-tests - test - - test - - - - org.apache.beam.runners.spark.StreamingTest - - - - [ - "--runner=TestSparkRunner", - "--forceStreaming=true", - "--enableSparkMetricSinks=true" - ] - - - - From 4c488152c45ac7a6c344a21b67c968b97bf5066c Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Thu, 22 Jun 2017 08:33:31 -0700 Subject: [PATCH 086/200] [BEAM-1585] Fix the beam plugins installation --- sdks/python/apache_beam/options/pipeline_options.py | 8 ++++++-- .../apache_beam/runners/dataflow/dataflow_runner.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index dab8ff204d3a4..ea996a3d9fb56 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -552,13 +552,17 @@ def _add_argparse_args(cls, parser): 'worker will install the resulting package before running any custom ' 'code.')) parser.add_argument( - '--beam_plugins', + '--beam_plugin', '--beam_plugin', + dest='beam_plugins', + action='append', default=None, help= ('Bootstrap the python process before executing any code by importing ' 'all the plugins used in the pipeline. Please pass a comma separated' 'list of import paths to be included. This is currently an ' - 'experimental flag and provides no stability.')) + 'experimental flag and provides no stability. 
Multiple ' + '--beam_plugin options can be specified if more than one plugin ' + 'is needed.')) parser.add_argument( '--save_main_session', default=False, diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 9395f1688056e..f213b3b9db33c 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -247,7 +247,7 @@ def run(self, pipeline): setup_options = pipeline._options.view_as(SetupOptions) plugins = BeamPlugin.get_all_plugin_paths() if setup_options.beam_plugins is not None: - plugins = list(set(plugins + setup_options.beam_plugins.split(','))) + plugins = list(set(plugins + setup_options.beam_plugins)) setup_options.beam_plugins = plugins self.job = apiclient.Job(pipeline._options) From 8cab15338f811f880c6cfb820051cf355f92986b Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 21 Jun 2017 18:09:48 -0700 Subject: [PATCH 087/200] Java Dataflow runner harness compatibility. --- .../runners/portability/fn_api_runner.py | 6 ++++- .../apache_beam/runners/worker/sdk_worker.py | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index dabb7d687db51..a27e293696983 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -17,6 +17,7 @@ """A PipelineRunner using the SDK harness. """ +import base64 import collections import json import logging @@ -204,11 +205,14 @@ def get_outputs(op_ix): else: # Otherwise serialize the source and execute it there. # TODO: Use SDFs with an initial impulse. + # The Dataflow runner harness strips the base64 encoding. do the same + # here until we get the same thing back that we sent in. transform_spec = beam_runner_api_pb2.FunctionSpec( urn=sdk_worker.PYTHON_SOURCE_URN, parameter=proto_utils.pack_Any( wrappers_pb2.BytesValue( - value=pickler.dumps(operation.source.source)))) + value=base64.b64decode( + pickler.dumps(operation.source.source))))) elif isinstance(operation, operation_specs.WorkerDoFn): # Record the contents of each side input for access via the state api. diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index fd7ecc4325a89..a2c9f424bbf35 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function +import base64 import collections import json import logging @@ -195,7 +196,7 @@ def pack_function_spec_data(value, urn, id=None): # pylint: enable=redefined-builtin -# TODO(vikasrk): move this method to ``coders.py`` in the SDK. +# TODO(vikasrk): Consistently use same format everywhere. 
def load_compressed(compressed_data): """Returns a decompressed and deserialized python object.""" # Note: SDK uses ``pickler.dumps`` to serialize certain python objects @@ -259,6 +260,10 @@ def process_requests(): try: response = self.worker.do_instruction(work_request) except Exception: # pylint: disable=broad-except + logging.error( + 'Error processing instruction %s', + work_request.instruction_id, + exc_info=True) response = beam_fn_api_pb2.InstructionResponse( instruction_id=work_request.instruction_id, error=traceback.format_exc()) @@ -319,10 +324,10 @@ def initial_source_split(self, request, unused_instruction_id=None): return response def create_execution_tree(self, descriptor): - if descriptor.primitive_transform: - return self.create_execution_tree_from_fn_api(descriptor) - else: + if descriptor.transforms: return self.create_execution_tree_from_runner_api(descriptor) + else: + return self.create_execution_tree_from_fn_api(descriptor) def create_execution_tree_from_runner_api(self, descriptor): # TODO(robertwb): Figure out the correct prefix to use for output counters @@ -551,7 +556,15 @@ def create_operation(self, transform_id, consumers): return creator(self, transform_id, transform_proto, parameter, consumers) def get_coder(self, coder_id): - return self.context.coders.get_by_id(coder_id) + coder_proto = self.descriptor.codersyyy[coder_id] + if coder_proto.spec.spec.urn: + return self.context.coders.get_by_id(coder_id) + else: + # No URN, assume cloud object encoding json bytes. + return operation_specs.get_coder_from_spec( + json.loads( + proto_utils.unpack_Any(coder_proto.spec.spec.parameter, + wrappers_pb2.BytesValue).value)) def get_output_coders(self, transform_proto): return { @@ -618,7 +631,8 @@ def create(factory, transform_id, transform_proto, grpc_port, consumers): @BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue) def create(factory, transform_id, transform_proto, parameter, consumers): - source = pickler.loads(parameter.value) + # The Dataflow runner harness strips the base64 encoding. 
+ source = pickler.loads(base64.b64encode(parameter.value)) spec = operation_specs.WorkerRead( iobase.SourceBundle(1.0, source, None, None), [WindowedValueCoder(source.default_output_coder())]) From 7471e2736cc22336500f6252ab8448889a2d04d3 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Thu, 22 Jun 2017 11:29:54 -0700 Subject: [PATCH 088/200] Clean up test_stream_test and remove stray print statement --- .../runners/direct/watermark_manager.py | 1 - .../apache_beam/testing/test_stream_test.py | 16 ++++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 2146bb5d9b1c2..4aa2bb4342f18 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -175,7 +175,6 @@ def update_input_transform_watermarks(self, input_transform_watermarks): def update_timers(self, completed_timers): with self._lock: for timer_firing in completed_timers: - print 'REMOVE', timer_firing self._fired_timers.remove(timer_firing) @property diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index bf05ac16f7ef8..071c7cd3d6c00 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -25,6 +25,7 @@ from apache_beam.testing.test_stream import ProcessingTimeEvent from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_stream import WatermarkEvent +from apache_beam.testing.util import assert_that, equal_to from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -92,28 +93,23 @@ def test_basic_execution(self): .add_elements([TimestampedValue('late', 12)]) .add_elements([TimestampedValue('last', 310)])) - global _seen_elements # pylint: disable=global-variable-undefined - _seen_elements = [] - class RecordFn(beam.DoFn): def process(self, element=beam.DoFn.ElementParam, timestamp=beam.DoFn.TimestampParam): - _seen_elements.append((element, timestamp)) + yield (element, timestamp) p = TestPipeline() my_record_fn = RecordFn() - p | test_stream | beam.ParDo(my_record_fn) # pylint: disable=expression-not-assigned - p.run() - - self.assertEqual([ + records = p | test_stream | beam.ParDo(my_record_fn) + assert_that(records, equal_to([ ('a', timestamp.Timestamp(10)), ('b', timestamp.Timestamp(10)), ('c', timestamp.Timestamp(10)), ('d', timestamp.Timestamp(20)), ('e', timestamp.Timestamp(20)), ('late', timestamp.Timestamp(12)), - ('last', timestamp.Timestamp(310)),], _seen_elements) - del _seen_elements + ('last', timestamp.Timestamp(310)),])) + p.run() if __name__ == '__main__': From 0292a24f9c88796542bff55031d84c11f0ab6b16 Mon Sep 17 00:00:00 2001 From: Colin Phipps Date: Mon, 15 May 2017 14:18:16 +0000 Subject: [PATCH 089/200] [BEAM-2439] Dynamic sizing of Datastore write RPCs This stops the Datastore connector from always sending 500 entities per RPC. Instead, it starts at a lower number which is more likely to complete within the deadline even in adverse conditions, and then increases or reduces the batch size in response to measured latency of past requests. 
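In other words, the next batch size works out to the target latency budget divided by a
rolling mean latency per entity, clamped to a fixed range. A minimal sketch of that
calculation, restating the WriteBatcherImpl logic added in the diff below with its new
constants inlined as literals (illustration only, not part of the applied change):

    // Restates DatastoreV1.WriteBatcherImpl.nextBatchSize() with constants inlined.
    class BatchSizeSketch {
      static int nextBatchSize(long meanLatencyPerEntityMs) {
        if (meanLatencyPerEntityMs <= 0) {
          // No latency samples yet: start conservatively.
          return 200;                         // DATASTORE_BATCH_UPDATE_ENTITIES_START
        }
        // Entities that fit in the per-RPC latency budget of 5000 ms.
        long fitting = 5000 / meanLatencyPerEntityMs;  // DATASTORE_BATCH_TARGET_LATENCY_MS
        // Never exceed the API limit of 500, never drop below 10.
        return (int) Math.max(10,             // DATASTORE_BATCH_UPDATE_ENTITIES_MIN
            Math.min(500, fitting));          // DATASTORE_BATCH_UPDATE_ENTITIES_LIMIT
      }
    }

With the 5 s target, a measured 50 ms per entity yields batches of 100, which matches the
expectation in the new testWriteBatcherSlowQueries test below.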
--- .../sdk/io/gcp/datastore/DatastoreV1.java | 124 +++++++++++++++--- .../sdk/io/gcp/datastore/MovingAverage.java | 50 +++++++ .../sdk/io/gcp/datastore/DatastoreV1Test.java | 72 +++++++++- 3 files changed, 225 insertions(+), 21 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/MovingAverage.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java index 06b9c8af9319f..e67f4b2fcd972 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java @@ -201,11 +201,31 @@ public class DatastoreV1 { DatastoreV1() {} /** - * Cloud Datastore has a limit of 500 mutations per batch operation, so we flush - * changes to Datastore every 500 entities. + * The number of entity updates written per RPC, initially. We buffer updates in the connector and + * write a batch to Datastore once we have collected a certain number. This is the initial batch + * size; it is adjusted at runtime based on the performance of previous writes (see {@link + * DatastoreV1.WriteBatcher}). + * + *
<p>
    Testing has found that a batch of 200 entities will generally finish within the timeout even + * in adverse conditions. + */ + @VisibleForTesting + static final int DATASTORE_BATCH_UPDATE_ENTITIES_START = 200; + + /** + * When choosing the number of updates in a single RPC, never exceed the maximum allowed by the + * API. */ @VisibleForTesting - static final int DATASTORE_BATCH_UPDATE_LIMIT = 500; + static final int DATASTORE_BATCH_UPDATE_ENTITIES_LIMIT = 500; + + /** + * When choosing the number of updates in a single RPC, do not go below this value. The actual + * number of entities per request may be lower when we flush for the end of a bundle or if we hit + * {@link DatastoreV1.DATASTORE_BATCH_UPDATE_BYTES_LIMIT}. + */ + @VisibleForTesting + static final int DATASTORE_BATCH_UPDATE_ENTITIES_MIN = 10; /** * Cloud Datastore has a limit of 10MB per RPC, so we also flush if the total size of mutations @@ -1107,18 +1127,74 @@ public String getProjectId() { } } + /** Determines batch sizes for commit RPCs. */ + @VisibleForTesting + interface WriteBatcher { + /** Call before using this WriteBatcher. */ + void start(); + + /** + * Reports the latency of a previous commit RPC, and the number of mutations that it contained. + */ + void addRequestLatency(long timeSinceEpochMillis, long latencyMillis, int numMutations); + + /** Returns the number of entities to include in the next CommitRequest. */ + int nextBatchSize(long timeSinceEpochMillis); + } + + /** + * Determines batch sizes for commit RPCs based on past performance. + * + *
<p>
    It aims for a target response time per RPC: it uses the response times for previous RPCs + * and the number of entities contained in them, calculates a rolling average time-per-entity, and + * chooses the number of entities for future writes to hit the target time. + * + *
<p>
    This enables us to send large batches without sending over-large requests in the case of + * expensive entity writes that may timeout before the server can apply them all. + */ + @VisibleForTesting + static class WriteBatcherImpl implements WriteBatcher, Serializable { + /** Target time per RPC for writes. */ + static final int DATASTORE_BATCH_TARGET_LATENCY_MS = 5000; + + @Override + public void start() { + meanLatencyPerEntityMs = new MovingAverage( + 120000 /* sample period 2 minutes */, 10000 /* sample interval 10s */, + 1 /* numSignificantBuckets */, 1 /* numSignificantSamples */); + } + + @Override + public void addRequestLatency(long timeSinceEpochMillis, long latencyMillis, int numMutations) { + meanLatencyPerEntityMs.add(timeSinceEpochMillis, latencyMillis / numMutations); + } + + @Override + public int nextBatchSize(long timeSinceEpochMillis) { + if (!meanLatencyPerEntityMs.hasValue(timeSinceEpochMillis)) { + return DATASTORE_BATCH_UPDATE_ENTITIES_START; + } + long recentMeanLatency = Math.max(meanLatencyPerEntityMs.get(timeSinceEpochMillis), 1); + return (int) Math.max(DATASTORE_BATCH_UPDATE_ENTITIES_MIN, + Math.min(DATASTORE_BATCH_UPDATE_ENTITIES_LIMIT, + DATASTORE_BATCH_TARGET_LATENCY_MS / recentMeanLatency)); + } + + private transient MovingAverage meanLatencyPerEntityMs; + } + /** * {@link DoFn} that writes {@link Mutation}s to Cloud Datastore. Mutations are written in - * batches, where the maximum batch size is {@link DatastoreV1#DATASTORE_BATCH_UPDATE_LIMIT}. + * batches; see {@link DatastoreV1.WriteBatcherImpl}. * *
<p>
    See * Datastore: Entities, Properties, and Keys for information about entity keys and mutations. * *
<p>
    Commits are non-transactional. If a commit fails because of a conflict over an entity - * group, the commit will be retried (up to {@link DatastoreV1#DATASTORE_BATCH_UPDATE_LIMIT} + * group, the commit will be retried (up to {@link DatastoreV1.DatastoreWriterFn#MAX_RETRIES} * times). This means that the mutation operation should be idempotent. Thus, the writer should - * only be used for {code upsert} and {@code delete} mutation operations, as these are the only + * only be used for {@code upsert} and {@code delete} mutation operations, as these are the only * two Cloud Datastore mutations that are idempotent. */ @VisibleForTesting @@ -1132,6 +1208,7 @@ static class DatastoreWriterFn extends DoFn { // Current batch of mutations to be written. private final List mutations = new ArrayList<>(); private int mutationsSize = 0; // Accumulated size of protos in mutations. + private WriteBatcher writeBatcher; private static final int MAX_RETRIES = 5; private static final FluentBackoff BUNDLE_WRITE_BACKOFF = @@ -1139,24 +1216,27 @@ static class DatastoreWriterFn extends DoFn { .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5)); DatastoreWriterFn(String projectId, @Nullable String localhost) { - this(StaticValueProvider.of(projectId), localhost, new V1DatastoreFactory()); + this(StaticValueProvider.of(projectId), localhost, new V1DatastoreFactory(), + new WriteBatcherImpl()); } DatastoreWriterFn(ValueProvider projectId, @Nullable String localhost) { - this(projectId, localhost, new V1DatastoreFactory()); + this(projectId, localhost, new V1DatastoreFactory(), new WriteBatcherImpl()); } @VisibleForTesting DatastoreWriterFn(ValueProvider projectId, @Nullable String localhost, - V1DatastoreFactory datastoreFactory) { + V1DatastoreFactory datastoreFactory, WriteBatcher writeBatcher) { this.projectId = checkNotNull(projectId, "projectId"); this.localhost = localhost; this.datastoreFactory = datastoreFactory; + this.writeBatcher = writeBatcher; } @StartBundle public void startBundle(StartBundleContext c) { datastore = datastoreFactory.getDatastore(c.getPipelineOptions(), projectId.get(), localhost); + writeBatcher.start(); } @ProcessElement @@ -1169,7 +1249,7 @@ public void processElement(ProcessContext c) throws Exception { } mutations.add(c.element()); mutationsSize += size; - if (mutations.size() >= DatastoreV1.DATASTORE_BATCH_UPDATE_LIMIT) { + if (mutations.size() >= writeBatcher.nextBatchSize(System.currentTimeMillis())) { flushBatch(); } } @@ -1199,18 +1279,32 @@ private void flushBatch() throws DatastoreException, IOException, InterruptedExc while (true) { // Batch upsert entities. + CommitRequest.Builder commitRequest = CommitRequest.newBuilder(); + commitRequest.addAllMutations(mutations); + commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); + long startTime = System.currentTimeMillis(), endTime; + try { - CommitRequest.Builder commitRequest = CommitRequest.newBuilder(); - commitRequest.addAllMutations(mutations); - commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); datastore.commit(commitRequest.build()); + endTime = System.currentTimeMillis(); + + writeBatcher.addRequestLatency(endTime, endTime - startTime, mutations.size()); + // Break if the commit threw no exception. break; } catch (DatastoreException exception) { + if (exception.getCode() == Code.DEADLINE_EXCEEDED) { + /* Most errors are not related to request size, and should not change our expectation of + * the latency of successful requests. 
DEADLINE_EXCEEDED can be taken into + * consideration, though. */ + endTime = System.currentTimeMillis(); + writeBatcher.addRequestLatency(endTime, endTime - startTime, mutations.size()); + } + // Only log the code and message for potentially-transient errors. The entire exception // will be propagated upon the last retry. - LOG.error("Error writing to the Datastore ({}): {}", exception.getCode(), - exception.getMessage()); + LOG.error("Error writing batch of {} mutations to Datastore ({}): {}", mutations.size(), + exception.getCode(), exception.getMessage()); if (!BackOffUtils.next(sleeper, backoff)) { LOG.error("Aborting after {} retries.", MAX_RETRIES); throw exception; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/MovingAverage.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/MovingAverage.java new file mode 100644 index 0000000000000..0890e79473f1a --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/MovingAverage.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.sdk.io.gcp.datastore; + +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.util.MovingFunction; + + +class MovingAverage { + private final MovingFunction sum; + private final MovingFunction count; + + public MovingAverage(long samplePeriodMs, long sampleUpdateMs, + int numSignificantBuckets, int numSignificantSamples) { + sum = new MovingFunction(samplePeriodMs, sampleUpdateMs, + numSignificantBuckets, numSignificantSamples, Sum.ofLongs()); + count = new MovingFunction(samplePeriodMs, sampleUpdateMs, + numSignificantBuckets, numSignificantSamples, Sum.ofLongs()); + } + + public void add(long nowMsSinceEpoch, long value) { + sum.add(nowMsSinceEpoch, value); + count.add(nowMsSinceEpoch, 1); + } + + public long get(long nowMsSinceEpoch) { + return sum.get(nowMsSinceEpoch) / count.get(nowMsSinceEpoch); + } + + public boolean hasValue(long nowMsSinceEpoch) { + return sum.isSignificant() && count.isSignificant() + && count.get(nowMsSinceEpoch) > 0; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java index 229b1fbb23672..946887c865e3c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java @@ -27,7 +27,7 @@ import static com.google.datastore.v1.client.DatastoreHelper.makeUpsert; import static com.google.datastore.v1.client.DatastoreHelper.makeValue; import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.DATASTORE_BATCH_UPDATE_BYTES_LIMIT; -import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.DATASTORE_BATCH_UPDATE_LIMIT; +import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_START; import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.Read.DEFAULT_BUNDLE_SIZE_BYTES; import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.Read.QUERY_BATCH_LIMIT; import static org.apache.beam.sdk.io.gcp.datastore.DatastoreV1.Read.getEstimatedSizeBytes; @@ -606,7 +606,7 @@ public void testDatatoreWriterFnWithOneBatch() throws Exception { /** Tests {@link DatastoreWriterFn} with entities of more than one batches, but not a multiple. */ @Test public void testDatatoreWriterFnWithMultipleBatches() throws Exception { - datastoreWriterFnTest(DATASTORE_BATCH_UPDATE_LIMIT * 3 + 100); + datastoreWriterFnTest(DATASTORE_BATCH_UPDATE_ENTITIES_START * 3 + 100); } /** @@ -615,7 +615,7 @@ public void testDatatoreWriterFnWithMultipleBatches() throws Exception { */ @Test public void testDatatoreWriterFnWithBatchesExactMultiple() throws Exception { - datastoreWriterFnTest(DATASTORE_BATCH_UPDATE_LIMIT * 2); + datastoreWriterFnTest(DATASTORE_BATCH_UPDATE_ENTITIES_START * 2); } // A helper method to test DatastoreWriterFn for various batch sizes. 
@@ -628,14 +628,14 @@ private void datastoreWriterFnTest(int numMutations) throws Exception { } DatastoreWriterFn datastoreWriter = new DatastoreWriterFn(StaticValueProvider.of(PROJECT_ID), - null, mockDatastoreFactory); + null, mockDatastoreFactory, new FakeWriteBatcher()); DoFnTester doFnTester = DoFnTester.of(datastoreWriter); doFnTester.setCloningBehavior(CloningBehavior.DO_NOT_CLONE); doFnTester.processBundle(mutations); int start = 0; while (start < numMutations) { - int end = Math.min(numMutations, start + DATASTORE_BATCH_UPDATE_LIMIT); + int end = Math.min(numMutations, start + DATASTORE_BATCH_UPDATE_ENTITIES_START); CommitRequest.Builder commitRequest = CommitRequest.newBuilder(); commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); commitRequest.addAllMutations(mutations.subList(start, end)); @@ -662,7 +662,7 @@ public void testDatatoreWriterFnWithLargeEntities() throws Exception { } DatastoreWriterFn datastoreWriter = new DatastoreWriterFn(StaticValueProvider.of(PROJECT_ID), - null, mockDatastoreFactory); + null, mockDatastoreFactory, new FakeWriteBatcher()); DoFnTester doFnTester = DoFnTester.of(datastoreWriter); doFnTester.setCloningBehavior(CloningBehavior.DO_NOT_CLONE); doFnTester.processBundle(mutations); @@ -896,6 +896,50 @@ public void testRuntimeOptionsNotCalledInApplyGqlQuery() { .apply(DatastoreIO.v1().write().withProjectId(options.getDatastoreProject())); } + @Test + public void testWriteBatcherWithoutData() { + DatastoreV1.WriteBatcher writeBatcher = new DatastoreV1.WriteBatcherImpl(); + writeBatcher.start(); + assertEquals(DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_START, writeBatcher.nextBatchSize(0)); + } + + @Test + public void testWriteBatcherFastQueries() { + DatastoreV1.WriteBatcher writeBatcher = new DatastoreV1.WriteBatcherImpl(); + writeBatcher.start(); + writeBatcher.addRequestLatency(0, 1000, 200); + writeBatcher.addRequestLatency(0, 1000, 200); + assertEquals(DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_LIMIT, writeBatcher.nextBatchSize(0)); + } + + @Test + public void testWriteBatcherSlowQueries() { + DatastoreV1.WriteBatcher writeBatcher = new DatastoreV1.WriteBatcherImpl(); + writeBatcher.start(); + writeBatcher.addRequestLatency(0, 10000, 200); + writeBatcher.addRequestLatency(0, 10000, 200); + assertEquals(100, writeBatcher.nextBatchSize(0)); + } + + @Test + public void testWriteBatcherSizeNotBelowMinimum() { + DatastoreV1.WriteBatcher writeBatcher = new DatastoreV1.WriteBatcherImpl(); + writeBatcher.start(); + writeBatcher.addRequestLatency(0, 30000, 50); + writeBatcher.addRequestLatency(0, 30000, 50); + assertEquals(DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_MIN, writeBatcher.nextBatchSize(0)); + } + + @Test + public void testWriteBatcherSlidingWindow() { + DatastoreV1.WriteBatcher writeBatcher = new DatastoreV1.WriteBatcherImpl(); + writeBatcher.start(); + writeBatcher.addRequestLatency(0, 30000, 50); + writeBatcher.addRequestLatency(50000, 5000, 200); + writeBatcher.addRequestLatency(100000, 5000, 200); + assertEquals(200, writeBatcher.nextBatchSize(150000)); + } + /** Helper Methods */ /** A helper function that verifies if all the queries have unique keys. */ @@ -1039,4 +1083,20 @@ private List splitQuery(Query query, int numSplits) { } return queries; } + + /** + * A WriteBatcher for unit tests, which does no timing-based adjustments (so unit tests have + * consistent results). 
+ */ + static class FakeWriteBatcher implements DatastoreV1.WriteBatcher { + @Override + public void start() {} + @Override + public void addRequestLatency(long timeSinceEpochMillis, long latencyMillis, int numMutations) { + } + @Override + public int nextBatchSize(long timeSinceEpochMillis) { + return DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_START; + } + } } From c05764454e73ab93d0602f34b8b7622d46e1d892 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 21 Jun 2017 20:58:35 -0700 Subject: [PATCH 090/200] DataflowRunner: Reject SetState and MapState --- .../dataflow/BatchStatefulParDoOverrides.java | 2 + .../dataflow/DataflowPipelineTranslator.java | 2 + .../beam/runners/dataflow/DataflowRunner.java | 30 +++++++ .../runners/dataflow/DataflowRunnerTest.java | 89 +++++++++++++++++-- 4 files changed, 114 insertions(+), 9 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index 4d9a57fbf977d..41202db0e4690 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -145,6 +145,7 @@ ParDo.SingleOutput, OutputT> getOriginalParDo() { public PCollection expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); + DataflowRunner.verifyStateSupported(fn); PTransform< PCollection>>>>>, @@ -169,6 +170,7 @@ static class StatefulMultiOutputParDo public PCollectionTuple expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); + DataflowRunner.verifyStateSupported(fn); PTransform< PCollection>>>>>, diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index bfd9b649add4b..6d3054407b216 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -972,6 +972,8 @@ private static void translateFn( fn)); } + DataflowRunner.verifyStateSupported(fn); + stepContext.addInput(PropertyNames.USER_FN, fn.getClass().getName()); stepContext.addInput( PropertyNames.SERIALIZED_FN, diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 1741287d77036..4d7f6acfef3bb 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -107,6 +107,8 @@ import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Combine.GroupedValues; import 
org.apache.beam.sdk.transforms.Create; @@ -119,6 +121,8 @@ import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.InstanceBuilder; @@ -136,6 +140,7 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.ValueWithRecordId; import org.apache.beam.sdk.values.WindowingStrategy; import org.joda.time.DateTimeUtils; @@ -1512,4 +1517,29 @@ static String getContainerImageForJob(DataflowPipelineOptions options) { return workerHarnessContainerImage.replace("IMAGE", "beam-java-batch"); } } + + static void verifyStateSupported(DoFn fn) { + DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + + for (DoFnSignature.StateDeclaration stateDecl : signature.stateDeclarations().values()) { + + // https://issues.apache.org/jira/browse/BEAM-1474 + if (stateDecl.stateType().isSubtypeOf(TypeDescriptor.of(MapState.class))) { + throw new UnsupportedOperationException(String.format( + "%s does not currently support %s", + DataflowRunner.class.getSimpleName(), + MapState.class.getSimpleName() + )); + } + + // https://issues.apache.org/jira/browse/BEAM-1479 + if (stateDecl.stateType().isSubtypeOf(TypeDescriptor.of(SetState.class))) { + throw new UnsupportedOperationException(String.format( + "%s does not currently support %s", + DataflowRunner.class.getSimpleName(), + SetState.class.getSimpleName() + )); + } + } + } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index aae21cffd4c00..f57c0ee5ad3ca 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -50,6 +50,7 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.Serializable; import java.net.URL; import java.net.URLClassLoader; import java.nio.channels.FileChannel; @@ -82,18 +83,26 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.SetState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import 
org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.util.GcsUtil; import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.util.gcsfs.GcsPath; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PValue; @@ -120,7 +129,7 @@ * Tests for the {@link DataflowRunner}. */ @RunWith(JUnit4.class) -public class DataflowRunnerTest { +public class DataflowRunnerTest implements Serializable { private static final String VALID_STAGING_BUCKET = "gs://valid-bucket/staging"; private static final String VALID_TEMP_BUCKET = "gs://valid-bucket/temp"; @@ -130,15 +139,12 @@ public class DataflowRunnerTest { private static final String PROJECT_ID = "some-project"; private static final String REGION_ID = "some-region-1"; - @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Rule - public ExpectedLogs expectedLogs = ExpectedLogs.none(DataflowRunner.class); + @Rule public transient TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public transient ExpectedException thrown = ExpectedException.none(); + @Rule public transient ExpectedLogs expectedLogs = ExpectedLogs.none(DataflowRunner.class); - private Dataflow.Projects.Locations.Jobs mockJobs; - private GcsUtil mockGcsUtil; + private transient Dataflow.Projects.Locations.Jobs mockJobs; + private transient GcsUtil mockGcsUtil; // Asserts that the given Job has all expected fields set. private static void assertValidJob(Job job) { @@ -1001,6 +1007,71 @@ public void translate( assertTrue(transform.translated); } + private void verifyMapStateUnsupported(PipelineOptions options) throws Exception { + Pipeline p = Pipeline.create(options); + p.apply(Create.of(KV.of(13, 42))) + .apply( + ParDo.of( + new DoFn, Void>() { + @StateId("fizzle") + private final StateSpec> voidState = StateSpecs.map(); + + @ProcessElement + public void process() {} + })); + + thrown.expectMessage("MapState"); + thrown.expect(UnsupportedOperationException.class); + p.run(); + } + + @Test + public void testMapStateUnsupportedInBatch() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(false); + verifyMapStateUnsupported(options); + } + + @Test + public void testMapStateUnsupportedInStreaming() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(true); + verifyMapStateUnsupported(options); + } + + private void verifySetStateUnsupported(PipelineOptions options) throws Exception { + Pipeline p = Pipeline.create(options); + p.apply(Create.of(KV.of(13, 42))) + .apply( + ParDo.of( + new DoFn, Void>() { + @StateId("fizzle") + private final StateSpec> voidState = StateSpecs.set(); + + @ProcessElement + public void process() {} + })); + + thrown.expectMessage("SetState"); + thrown.expect(UnsupportedOperationException.class); + p.run(); + } + + @Test + public void testSetStateUnsupportedInBatch() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(false); + Pipeline p = Pipeline.create(options); + verifySetStateUnsupported(options); + } + + @Test + public void testSetStateUnsupportedInStreaming() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(true); + verifySetStateUnsupported(options); + } + /** 
Records all the composite transforms visited within the Pipeline. */ private static class CompositeTransformRecorder extends PipelineVisitor.Defaults { private List> transforms = new ArrayList<>(); From 497cfabea7d6dcee0c5d327022678c571c3ec487 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 11:31:28 -0700 Subject: [PATCH 091/200] Add window matcher for pane info --- .../apache/beam/runners/core/WindowMatchers.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/WindowMatchers.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/WindowMatchers.java index 9769d10f85849..26cbfee92275c 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/WindowMatchers.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/WindowMatchers.java @@ -115,6 +115,21 @@ public static Matcher> isSingleWindowedValue( Matchers.anything()); } + public static Matcher> isSingleWindowedValue( + Matcher valueMatcher, + long timestamp, + long windowStart, + long windowEnd, + PaneInfo paneInfo) { + IntervalWindow intervalWindow = + new IntervalWindow(new Instant(windowStart), new Instant(windowEnd)); + return WindowMatchers.isSingleWindowedValue( + valueMatcher, + Matchers.describedAs("%0", Matchers.equalTo(new Instant(timestamp)), timestamp), + Matchers.equalTo(intervalWindow), + Matchers.equalTo(paneInfo)); + } + public static Matcher> isSingleWindowedValue( Matcher valueMatcher, Matcher timestampMatcher, From d4e5db51a025a831ddf4e3bc0e003caebabf647b Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 11:56:53 -0700 Subject: [PATCH 092/200] Tidy LateDataDroppingDoFnRunner --- .../core/LateDataDroppingDoFnRunner.java | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java index 1cf150973f7ac..28938c1f8c1d3 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java @@ -134,26 +134,27 @@ public WindowedValue apply(BoundedWindow window) { // The element is too late for this window. 
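          // For clarity: "too late" here means the element's window has expired, i.e. the
          // window's garbage-collection time is already behind the input watermark. The
          // actual canDropDueToExpiredWindow implementation is not part of this hunk;
          // conceptually it is equivalent to
          //   window.maxTimestamp()
          //       .plus(windowingStrategy.getAllowedLateness())
          //       .isBefore(timerInternals.currentInputWatermarkTime())
          // (field names follow the surrounding class and are assumptions insofar as they
          // are not shown in this diff).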
droppedDueToLateness.inc(); WindowTracing.debug( - "ReduceFnRunner.processElement: Dropping element at {} for key:{}; window:{} " - + "since too far behind inputWatermark:{}; outputWatermark:{}", - input.getTimestamp(), key, window, timerInternals.currentInputWatermarkTime(), + "{}: Dropping element at {} for key:{}; window:{} " + + "since too far behind inputWatermark:{}; outputWatermark:{}", + LateDataFilter.class.getSimpleName(), + input.getTimestamp(), + key, + window, + timerInternals.currentInputWatermarkTime(), timerInternals.currentOutputWatermarkTime()); } } - Iterable> nonLateElements = Iterables.filter( - concatElements, - new Predicate>() { - @Override - public boolean apply(WindowedValue input) { - BoundedWindow window = Iterables.getOnlyElement(input.getWindows()); - if (canDropDueToExpiredWindow(window)) { - return false; - } else { - return true; - } - } - }); + Iterable> nonLateElements = + Iterables.filter( + concatElements, + new Predicate>() { + @Override + public boolean apply(WindowedValue input) { + BoundedWindow window = Iterables.getOnlyElement(input.getWindows()); + return !canDropDueToExpiredWindow(window); + } + }); return nonLateElements; } From 1c1f239501349f5120b0d619c4eea9c435500b78 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 12:52:42 -0700 Subject: [PATCH 093/200] ReduceFnTester can advance clocks without firing timers --- .../beam/runners/core/ReduceFnTester.java | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java index 7f83eae787243..ab9fd6e9cb196 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java @@ -420,6 +420,10 @@ public WindowedValue apply(WindowedValue> input) { return result; } + public void advanceInputWatermarkNoTimers(Instant newInputWatermark) throws Exception { + timerInternals.advanceInputWatermark(newInputWatermark); + } + /** * Advance the input watermark to the specified time, firing any timers that should * fire. Then advance the output watermark as far as possible. @@ -451,6 +455,10 @@ public void advanceInputWatermark(Instant newInputWatermark) throws Exception { runner.persist(); } + public void advanceProcessingTimeNoTimers(Instant newProcessingTime) throws Exception { + timerInternals.advanceProcessingTime(newProcessingTime); + } + /** * If {@link #autoAdvanceOutputWatermark} is {@literal false}, advance the output watermark * to the given value. Otherwise throw. @@ -535,13 +543,27 @@ public WindowedValue apply(TimestampedValue input) { public void fireTimer(W window, Instant timestamp, TimeDomain domain) throws Exception { ReduceFnRunner runner = createRunner(); - ArrayList timers = new ArrayList(1); + ArrayList timers = new ArrayList<>(1); timers.add( TimerData.of(StateNamespaces.window(windowFn.windowCoder(), window), timestamp, domain)); runner.onTimers(timers); runner.persist(); } + public void fireTimers(W window, TimestampedValue... 
timers) throws Exception { + ReduceFnRunner runner = createRunner(); + ArrayList timerData = new ArrayList<>(timers.length); + for (TimestampedValue timer : timers) { + timerData.add( + TimerData.of( + StateNamespaces.window(windowFn.windowCoder(), window), + timer.getTimestamp(), + timer.getValue())); + } + runner.onTimers(timerData); + runner.persist(); + } + /** * Convey the simulated state and implement {@link #outputWindowedValue} to capture all output * elements. From 795760d370bcbe28e1f0ca373ad4c8c841e6e6b5 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 12:53:15 -0700 Subject: [PATCH 094/200] ReduceFnTester assertion for windows that have data buffered --- .../apache/beam/runners/core/SystemReduceFn.java | 6 ++++++ .../apache/beam/runners/core/ReduceFnTester.java | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SystemReduceFn.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SystemReduceFn.java index c189b0d7ee3f1..3144bd6f22549 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SystemReduceFn.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SystemReduceFn.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.core; +import com.google.common.annotations.VisibleForTesting; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; @@ -103,6 +104,11 @@ public SystemReduceFn( this.bufferTag = bufferTag; } + @VisibleForTesting + StateTag> getBufferTag() { + return bufferTag; + } + @Override public void processValue(ProcessValueContext c) throws Exception { c.state().access(bufferTag).add(c.value()); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java index ab9fd6e9cb196..1fe8f73f4c28e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java @@ -317,6 +317,19 @@ public final void assertHasOnlyGlobalAndFinishedSetsFor(W... expectedWindows) { ImmutableSet.>of(TriggerStateMachineRunner.FINISHED_BITS_TAG)); } + @SafeVarargs + public final void assertHasOnlyGlobalAndStateFor(W... expectedWindows) { + assertHasOnlyGlobalAndAllowedTags( + ImmutableSet.copyOf(expectedWindows), + ImmutableSet.>of( + ((SystemReduceFn) reduceFn).getBufferTag(), + TriggerStateMachineRunner.FINISHED_BITS_TAG, + PaneInfoTracker.PANE_INFO_TAG, + WatermarkHold.watermarkHoldTagForTimestampCombiner( + objectStrategy.getTimestampCombiner()), + WatermarkHold.EXTRA_HOLD_TAG)); + } + @SafeVarargs public final void assertHasOnlyGlobalAndFinishedSetsAndPaneInfoFor(W... 
expectedWindows) { assertHasOnlyGlobalAndAllowedTags( From 412fd7eab9e58a4d412f4dff5ffec023610b4f22 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 12:56:14 -0700 Subject: [PATCH 095/200] Drop late data in ReduceFnTester --- .../org/apache/beam/runners/core/ReduceFnTester.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java index 1fe8f73f4c28e..7ca96b9b549de 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java @@ -529,8 +529,8 @@ public final void injectElements(TimestampedValue... values) throws Exce for (TimestampedValue value : values) { WindowTracing.trace("TriggerTester.injectElements: {}", value); } - ReduceFnRunner runner = createRunner(); - runner.processElements( + + Iterable> inputs = Iterables.transform( Arrays.asList(values), new Function, WindowedValue>() { @@ -548,7 +548,12 @@ public WindowedValue apply(TimestampedValue input) { throw new RuntimeException(e); } } - })); + }); + + ReduceFnRunner runner = createRunner(); + runner.processElements( + new LateDataDroppingDoFnRunner.LateDataFilter(objectStrategy, timerInternals) + .filter(KEY, inputs)); // Persist after each bundle. runner.persist(); From 50c43d96adb8c2523cf38c09f32e241eacc47823 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 12:56:34 -0700 Subject: [PATCH 096/200] Do not GC windows based on processing time timer! --- .../beam/runners/core/ReduceFnRunner.java | 3 +- .../beam/runners/core/ReduceFnRunnerTest.java | 35 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java index b5c3e3ecc016a..75b6acda3312a 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java @@ -663,7 +663,8 @@ private class EnrichedTimerData { this.isEndOfWindow = TimeDomain.EVENT_TIME == timer.getDomain() && timer.getTimestamp().equals(window.maxTimestamp()); Instant cleanupTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy); - this.isGarbageCollection = !timer.getTimestamp().isBefore(cleanupTime); + this.isGarbageCollection = + TimeDomain.EVENT_TIME == timer.getDomain() && !timer.getTimestamp().isBefore(cleanupTime); } // Has this window had its trigger finish? diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java index 9e71300fb335b..2b661626e871f 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java @@ -140,7 +140,40 @@ public Void answer(InvocationOnMock invocation) throws Exception { } }) .when(mockTrigger).onFire(anyTriggerContext()); - } + } + + /** + * Tests that a processing time timer does not cause window GC. 
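The behavior this test guards is the one-line change to EnrichedTimerData above: only an event-time timer may be treated as a garbage-collection timer. A processing-time timer carries a wall-clock timestamp (5010 in this test), which is almost always past the window's event-time cleanup time (here on the order of the 100 ms window plus zero allowed lateness), so without the TimeDomain check such a timer was misclassified as a GC timer and the window's state was torn down early. Restated from the hunk above, with explanatory comments added for this note:

// cleanupTime is an event-time instant (derived from the window's maximum timestamp and the
// allowed lateness), so comparing a PROCESSING_TIME timer's wall-clock timestamp against it
// is meaningless; only EVENT_TIME timers are allowed to trigger garbage collection.
this.isGarbageCollection =
    TimeDomain.EVENT_TIME == timer.getDomain() && !timer.getTimestamp().isBefore(cleanupTime);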
+ */ + @Test + public void testProcessingTimeTimerDoesNotGc() throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.ZERO) + .withTrigger( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.millis(10)))); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceProcessingTime(new Instant(5000)); + injectElement(tester, 2); // processing timer @ 5000 + 10; EOW timer @ 100 + injectElement(tester, 5); + + tester.advanceProcessingTime(new Instant(10000)); + + tester.assertHasOnlyGlobalAndStateFor( + new IntervalWindow(new Instant(0), new Instant(100))); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue( + equalTo(7), 2, 0, 100, PaneInfo.createPane(true, false, Timing.EARLY, 0, 0)))); + } @Test public void testOnElementBufferingDiscarding() throws Exception { From fda589c00c8920e76cfc9aaa87cecfa94077599d Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 13:04:23 -0700 Subject: [PATCH 097/200] Add test reproducing BEAM-2505, ignored --- .../beam/runners/core/ReduceFnRunnerTest.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java index 2b661626e871f..fa5ba8bda1da5 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java @@ -78,6 +78,7 @@ import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -349,6 +350,36 @@ public BoundedWindow window() { assertThat(tester.extractOutput(), contains(isWindowedValue(equalTo(55)))); } + /** + * Tests that if end-of-window and GC timers come in together, that the pane is correctly + * marked as final. + */ + @Test + @Ignore("https://issues.apache.org/jira/browse/BEAM-2505") + public void testCombiningAccumulatingEventTime() throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(1)) + .withTrigger(Repeatedly.forever(AfterWatermark.pastEndOfWindow())); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + injectElement(tester, 2); // processing timer @ 5000 + 10; EOW timer @ 100 + injectElement(tester, 5); + + tester.advanceInputWatermark(new Instant(1000)); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue( + equalTo(7), 2, 0, 100, PaneInfo.createPane(true, true, Timing.ON_TIME, 0, 0)))); + } + + @Test public void testOnElementCombiningAccumulating() throws Exception { // Test basic execution of a trigger using a non-combining window set and accumulating mode. 
From d2b384a20dbb0213d0f63e74713a06d63bad8d39 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 22 Jun 2017 13:05:42 -0700 Subject: [PATCH 098/200] Add tests for corner cases of processing time timers --- .../beam/runners/core/ReduceFnRunnerTest.java | 70 +++++++++++++++++++ .../beam/sdk/transforms/GroupByKeyTest.java | 39 +++++++++++ 2 files changed, 109 insertions(+) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java index fa5ba8bda1da5..4f68038f38a7f 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java @@ -283,6 +283,44 @@ public void testOnElementCombiningDiscarding() throws Exception { tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); } + /** + * Tests that when a processing time timer comes in after a window is expired + * but in the same bundle it does not cause a spurious output. + */ + @Test + public void testCombiningAccumulatingProcessingTime() throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.ZERO) + .withTrigger( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.millis(10)))); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceProcessingTime(new Instant(5000)); + injectElement(tester, 2); // processing timer @ 5000 + 10; EOW timer @ 100 + injectElement(tester, 5); + + tester.advanceInputWatermarkNoTimers(new Instant(100)); + tester.advanceProcessingTimeNoTimers(new Instant(5010)); + + // Fires the GC/EOW timer at the same time as the processing time timer. + tester.fireTimers( + new IntervalWindow(new Instant(0), new Instant(100)), + TimestampedValue.of(TimeDomain.EVENT_TIME, new Instant(100)), + TimestampedValue.of(TimeDomain.PROCESSING_TIME, new Instant(5010))); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue( + equalTo(7), 2, 0, 100, PaneInfo.createPane(true, true, Timing.ON_TIME, 0, 0)))); + } + /** * Tests that the garbage collection time for a fixed window does not overflow the end of time. */ @@ -350,6 +388,38 @@ public BoundedWindow window() { assertThat(tester.extractOutput(), contains(isWindowedValue(equalTo(55)))); } + /** + * Tests that when a processing time timers comes in after a window is expired + * and GC'd it does not cause a spurious output. 
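Both this test and the single-bundle variant above end with the same assertion. The annotated copy below is a reading aid only; the inline comments are not part of the original test, and the argument meanings come from the isSingleWindowedValue overload added earlier in this series and from PaneInfo.createPane.

assertThat(
    tester.extractOutput(),
    contains(                    // exactly one output element overall
        isSingleWindowedValue(
            equalTo(7),          // 2 + 5, the accumulated sum
            2,                   // output timestamp (TimestampCombiner.EARLIEST of 2 and 5)
            0, 100,              // the [0, 100) fixed window
            PaneInfo.createPane(
                true,            // first pane for this window
                true,            // last pane: a spurious extra firing would break this
                Timing.ON_TIME,
                0,               // pane index
                0))));           // on-time pane index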
+ */ + @Test + public void testCombiningAccumulatingProcessingTimeSeparateBundles() throws Exception { + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.ZERO) + .withTrigger( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.millis(10)))); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceProcessingTime(new Instant(5000)); + injectElement(tester, 2); // processing timer @ 5000 + 10; EOW timer @ 100 + injectElement(tester, 5); + + tester.advanceInputWatermark(new Instant(100)); + tester.advanceProcessingTime(new Instant(5011)); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue( + equalTo(7), 2, 0, 100, PaneInfo.createPane(true, true, Timing.ON_TIME, 0, 0)))); + } + /** * Tests that if end-of-window and GC timers come in together, that the pane is correctly * marked as final. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java index 171171f33cd4d..4b5d5f5e5975e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java @@ -45,14 +45,19 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.LargeKeys; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testing.UsesTestStream; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.InvalidWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.Window; @@ -184,6 +189,40 @@ public void testGroupByKeyEmpty() { p.run(); } + /** + * Tests that when a processing time timers comes in after a window is expired it does not cause a + * spurious output. 
+ */ + @Test + @Category({ValidatesRunner.class, UsesTestStream.class}) + public void testCombiningAccumulatingProcessingTime() throws Exception { + PCollection triggeredSums = + p.apply( + TestStream.create(VarIntCoder.of()) + .advanceWatermarkTo(new Instant(0)) + .addElements( + TimestampedValue.of(2, new Instant(2)), + TimestampedValue.of(5, new Instant(5))) + .advanceWatermarkTo(new Instant(100)) + .advanceProcessingTime(Duration.millis(10)) + .advanceWatermarkToInfinity()) + .apply( + Window.into(FixedWindows.of(Duration.millis(100))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .accumulatingFiredPanes() + .withAllowedLateness(Duration.ZERO) + .triggering( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(Duration.millis(10))))) + .apply(Sum.integersGlobally().withoutDefaults()); + + PAssert.that(triggeredSums) + .containsInAnyOrder(7); + + p.run(); + } + @Test public void testGroupByKeyNonDeterministic() throws Exception { From 5d6ad19958d0a2394f9e33720a04cc954279a7e7 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 22 Jun 2017 12:44:23 -0700 Subject: [PATCH 099/200] Remove fn api bundle descriptor translation. --- .../runners/portability/fn_api_runner.py | 191 +----------------- .../runners/portability/fn_api_runner_test.py | 18 +- .../apache_beam/runners/worker/sdk_worker.py | 150 -------------- 3 files changed, 4 insertions(+), 355 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index a27e293696983..b45ff76f7e27c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -115,13 +115,9 @@ def process(self, source): class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): - def __init__(self, use_runner_protos=False): + def __init__(self): super(FnApiRunner, self).__init__() self._last_uid = -1 - if use_runner_protos: - self._map_task_to_protos = self._map_task_to_runner_protos - else: - self._map_task_to_protos = self._map_task_to_fn_protos def has_metrics_support(self): return False @@ -145,7 +141,7 @@ def _map_task_registration(self, map_task, state_handler, process_bundle_descriptor=[process_bundle_descriptor]) ), runner_sinks, input_data - def _map_task_to_runner_protos(self, map_task, data_operation_spec): + def _map_task_to_protos(self, map_task, data_operation_spec): input_data = {} side_input_data = {} runner_sinks = {} @@ -265,189 +261,6 @@ def get_outputs(op_ix): environments=dict(context_proto.environments.items())) return input_data, side_input_data, runner_sinks, process_bundle_descriptor - def _map_task_to_fn_protos(self, map_task, data_operation_spec): - - input_data = {} - side_input_data = {} - runner_sinks = {} - transforms = [] - transform_index_to_id = {} - - # Maps coders to new coder objects and references. 
- coders = {} - - def coder_id(coder): - if coder not in coders: - coders[coder] = beam_fn_api_pb2.Coder( - function_spec=sdk_worker.pack_function_spec_data( - json.dumps(coder.as_cloud_object()), - sdk_worker.PYTHON_CODER_URN, id=self._next_uid())) - - return coders[coder].function_spec.id - - def output_tags(op): - return getattr(op, 'output_tags', ['out']) - - def as_target(op_input): - input_op_index, input_output_index = op_input - input_op = map_task[input_op_index][1] - return { - 'ignored_input_tag': - beam_fn_api_pb2.Target.List(target=[ - beam_fn_api_pb2.Target( - primitive_transform_reference=transform_index_to_id[ - input_op_index], - name=output_tags(input_op)[input_output_index]) - ]) - } - - def outputs(op): - return { - tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder)) - for tag, coder in zip(output_tags(op), op.output_coders) - } - - for op_ix, (stage_name, operation) in enumerate(map_task): - transform_id = transform_index_to_id[op_ix] = self._next_uid() - if isinstance(operation, operation_specs.WorkerInMemoryWrite): - # Write this data back to the runner. - fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN, - id=self._next_uid()) - if data_operation_spec: - fn.data.Pack(data_operation_spec) - inputs = as_target(operation.input) - side_inputs = {} - runner_sinks[(transform_id, 'out')] = operation - - elif isinstance(operation, operation_specs.WorkerRead): - # A Read is either translated to a direct injection of windowed values - # into the sdk worker, or an injection of the source object into the - # sdk worker as data followed by an SDF that reads that source. - if (isinstance(operation.source.source, - maptask_executor_runner.InMemorySource) - and isinstance(operation.source.source.default_output_coder(), - WindowedValueCoder)): - output_stream = create_OutputStream() - element_coder = ( - operation.source.source.default_output_coder().get_impl()) - # Re-encode the elements in the nested context and - # concatenate them together - for element in operation.source.source.read(None): - element_coder.encode_to_stream(element, output_stream, True) - target_name = self._next_uid() - input_data[(transform_id, target_name)] = output_stream.get() - fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN, - id=self._next_uid()) - if data_operation_spec: - fn.data.Pack(data_operation_spec) - inputs = {target_name: beam_fn_api_pb2.Target.List()} - side_inputs = {} - else: - # Read the source object from the runner. - source_coder = beam.coders.DillCoder() - input_transform_id = self._next_uid() - output_stream = create_OutputStream() - source_coder.get_impl().encode_to_stream( - GlobalWindows.windowed_value(operation.source), - output_stream, - True) - target_name = self._next_uid() - input_data[(input_transform_id, target_name)] = output_stream.get() - input_ptransform = beam_fn_api_pb2.PrimitiveTransform( - id=input_transform_id, - function_spec=beam_fn_api_pb2.FunctionSpec( - urn=sdk_worker.DATA_INPUT_URN, - id=self._next_uid()), - # TODO(robertwb): Possible name collision. - step_name=stage_name + '/inject_source', - inputs={target_name: beam_fn_api_pb2.Target.List()}, - outputs={ - 'out': - beam_fn_api_pb2.PCollection( - coder_reference=coder_id(source_coder)) - }) - if data_operation_spec: - input_ptransform.function_spec.data.Pack(data_operation_spec) - transforms.append(input_ptransform) - - # Read the elements out of the source. 
- fn = sdk_worker.pack_function_spec_data( - OLDE_SOURCE_SPLITTABLE_DOFN_DATA, - sdk_worker.PYTHON_DOFN_URN, - id=self._next_uid()) - inputs = { - 'ignored_input_tag': - beam_fn_api_pb2.Target.List(target=[ - beam_fn_api_pb2.Target( - primitive_transform_reference=input_transform_id, - name='out') - ]) - } - side_inputs = {} - - elif isinstance(operation, operation_specs.WorkerDoFn): - fn = sdk_worker.pack_function_spec_data( - operation.serialized_fn, - sdk_worker.PYTHON_DOFN_URN, - id=self._next_uid()) - inputs = as_target(operation.input) - # Store the contents of each side input for state access. - for si in operation.side_inputs: - assert isinstance(si.source, iobase.BoundedSource) - element_coder = si.source.default_output_coder() - view_id = self._next_uid() - # TODO(robertwb): Actually flesh out the ViewFn API. - side_inputs[si.tag] = beam_fn_api_pb2.SideInput( - view_fn=sdk_worker.serialize_and_pack_py_fn( - element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN, - id=view_id)) - # Re-encode the elements in the nested context and - # concatenate them together - output_stream = create_OutputStream() - for element in si.source.read( - si.source.get_range_tracker(None, None)): - element_coder.get_impl().encode_to_stream( - element, output_stream, True) - elements_data = output_stream.get() - side_input_data[view_id] = elements_data - - elif isinstance(operation, operation_specs.WorkerFlatten): - fn = sdk_worker.pack_function_spec_data( - operation.serialized_fn, - sdk_worker.IDENTITY_DOFN_URN, - id=self._next_uid()) - inputs = { - 'ignored_input_tag': - beam_fn_api_pb2.Target.List(target=[ - beam_fn_api_pb2.Target( - primitive_transform_reference=transform_index_to_id[ - input_op_index], - name=output_tags(map_task[input_op_index][1])[ - input_output_index]) - for input_op_index, input_output_index in operation.inputs - ]) - } - side_inputs = {} - - else: - raise TypeError(operation) - - ptransform = beam_fn_api_pb2.PrimitiveTransform( - id=transform_id, - function_spec=fn, - step_name=stage_name, - inputs=inputs, - side_inputs=side_inputs, - outputs=outputs(operation)) - transforms.append(ptransform) - - process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor( - id=self._next_uid(), - coders=coders.values(), - primitive_transform=transforms) - - return input_data, side_input_data, runner_sinks, process_bundle_descriptor - def _run_map_task( self, map_task, control_handler, state_handler, data_plane_handler, data_operation_spec): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index e2eae26b2179f..91590351e99ef 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -23,26 +23,12 @@ from apache_beam.runners.portability import maptask_executor_runner_test -class FnApiRunnerTestWithRunnerProtos( +class FnApiRunnerTest( maptask_executor_runner_test.MapTaskExecutorRunnerTest): def create_pipeline(self): return beam.Pipeline( - runner=fn_api_runner.FnApiRunner(use_runner_protos=True)) - - def test_combine_per_key(self): - # TODO(robertwb): Implement PGBKCV operation. 
- pass - - # Inherits all tests from maptask_executor_runner.MapTaskExecutorRunner - - -class FnApiRunnerTestWithFnProtos( - maptask_executor_runner_test.MapTaskExecutorRunnerTest): - - def create_pipeline(self): - return beam.Pipeline( - runner=fn_api_runner.FnApiRunner(use_runner_protos=False)) + runner=fn_api_runner.FnApiRunner()) def test_combine_per_key(self): # TODO(robertwb): Implement PGBKCV operation. diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index a2c9f424bbf35..d1359848d6417 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -196,25 +196,6 @@ def pack_function_spec_data(value, urn, id=None): # pylint: enable=redefined-builtin -# TODO(vikasrk): Consistently use same format everywhere. -def load_compressed(compressed_data): - """Returns a decompressed and deserialized python object.""" - # Note: SDK uses ``pickler.dumps`` to serialize certain python objects - # (like sources), which involves serialization, compression and base64 - # encoding. We cannot directly use ``pickler.loads`` for - # deserialization, as the runner would have already base64 decoded the - # data. So we only need to decompress and deserialize. - - data = zlib.decompress(compressed_data) - try: - return dill.loads(data) - except Exception: # pylint: disable=broad-except - dill.dill._trace(True) # pylint: disable=protected-access - return dill.loads(data) - finally: - dill.dill._trace(False) # pylint: disable=protected-access - - def memoize(func): cache = {} missing = object() @@ -324,12 +305,6 @@ def initial_source_split(self, request, unused_instruction_id=None): return response def create_execution_tree(self, descriptor): - if descriptor.transforms: - return self.create_execution_tree_from_runner_api(descriptor) - else: - return self.create_execution_tree_from_fn_api(descriptor) - - def create_execution_tree_from_runner_api(self, descriptor): # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. counter_factory = counters.CounterFactory() @@ -368,131 +343,6 @@ def topological_height(transform_id): for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)] - def create_execution_tree_from_fn_api(self, descriptor): - # TODO(vikasrk): Add an id field to Coder proto and use that instead. - coders = {coder.function_spec.id: operation_specs.get_coder_from_spec( - json.loads(unpack_function_spec_data(coder.function_spec))) - for coder in descriptor.coders} - - counter_factory = counters.CounterFactory() - # TODO(robertwb): Figure out the correct prefix to use for output counters - # from StateSampler. - state_sampler = statesampler.StateSampler( - 'fnapi-step%s-' % descriptor.id, counter_factory) - consumers = collections.defaultdict(lambda: collections.defaultdict(list)) - ops_by_id = {} - reversed_ops = [] - - for transform in reversed(descriptor.primitive_transform): - # TODO(robertwb): Figure out how to plumb through the operation name (e.g. - # "s3") from the service through the FnAPI so that msec counters can be - # reported and correctly plumbed through the service and the UI. 
- operation_name = 'fnapis%s' % transform.id - - def only_element(iterable): - element, = iterable - return element - - if transform.function_spec.urn == DATA_OUTPUT_URN: - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform.id, - name=only_element(transform.outputs.keys())) - - op = DataOutputOperation( - operation_name, - transform.step_name, - consumers[transform.id], - counter_factory, - state_sampler, - coders[only_element(transform.outputs.values()).coder_reference], - target, - self.data_channel_factory.create_data_channel( - transform.function_spec)) - - elif transform.function_spec.urn == DATA_INPUT_URN: - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform.id, - name=only_element(transform.inputs.keys())) - op = DataInputOperation( - operation_name, - transform.step_name, - consumers[transform.id], - counter_factory, - state_sampler, - coders[only_element(transform.outputs.values()).coder_reference], - target, - self.data_channel_factory.create_data_channel( - transform.function_spec)) - - elif transform.function_spec.urn == PYTHON_DOFN_URN: - def create_side_input(tag, si): - # TODO(robertwb): Extract windows (and keys) out of element data. - return operation_specs.WorkerSideInputSource( - tag=tag, - source=SideInputSource( - self.state_handler, - beam_fn_api_pb2.StateKey.MultimapSideInput( - key=si.view_fn.id.encode('utf-8')), - coder=unpack_and_deserialize_py_fn(si.view_fn))) - output_tags = list(transform.outputs.keys()) - spec = operation_specs.WorkerDoFn( - serialized_fn=unpack_function_spec_data(transform.function_spec), - output_tags=output_tags, - input=None, - side_inputs=[create_side_input(tag, si) - for tag, si in transform.side_inputs.items()], - output_coders=[coders[transform.outputs[out].coder_reference] - for out in output_tags]) - - op = operations.DoOperation(operation_name, spec, counter_factory, - state_sampler) - # TODO(robertwb): Move these to the constructor. - op.step_name = transform.step_name - for tag, op_consumers in consumers[transform.id].items(): - for consumer in op_consumers: - op.add_receiver( - consumer, output_tags.index(tag)) - - elif transform.function_spec.urn == IDENTITY_DOFN_URN: - op = operations.FlattenOperation(operation_name, None, counter_factory, - state_sampler) - # TODO(robertwb): Move these to the constructor. - op.step_name = transform.step_name - for tag, op_consumers in consumers[transform.id].items(): - for consumer in op_consumers: - op.add_receiver(consumer, 0) - - elif transform.function_spec.urn == PYTHON_SOURCE_URN: - source = load_compressed(unpack_function_spec_data( - transform.function_spec)) - # TODO(vikasrk): Remove this once custom source is implemented with - # splittable dofn via the data plane. - spec = operation_specs.WorkerRead( - iobase.SourceBundle(1.0, source, None, None), - [WindowedValueCoder(source.default_output_coder())]) - op = operations.ReadOperation(operation_name, spec, counter_factory, - state_sampler) - op.step_name = transform.step_name - output_tags = list(transform.outputs.keys()) - for tag, op_consumers in consumers[transform.id].items(): - for consumer in op_consumers: - op.add_receiver( - consumer, output_tags.index(tag)) - - else: - raise NotImplementedError - - # Record consumers. 
- for _, inputs in transform.inputs.items(): - for target in inputs.target: - consumers[target.primitive_transform_reference][target.name].append( - op) - - reversed_ops.append(op) - ops_by_id[transform.id] = op - - return list(reversed(reversed_ops)) - def process_bundle(self, request, instruction_id): ops = self.create_execution_tree( self.fns[request.process_bundle_descriptor_reference]) From a882e8f3a33c4a430f55d53b65285123c5a4f50d Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 22 Jun 2017 12:46:13 -0700 Subject: [PATCH 100/200] Remove unused (and untested) initial splittling logic. --- .../runners/portability/fn_api_runner.py | 1 - .../apache_beam/runners/worker/sdk_worker.py | 51 ------------ .../runners/worker/sdk_worker_test.py | 77 ------------------- 3 files changed, 129 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index b45ff76f7e27c..a8e2eb4573a1f 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -19,7 +19,6 @@ """ import base64 import collections -import json import logging import Queue as queue import threading diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index d1359848d6417..6a366ebcb2be1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -28,9 +28,7 @@ import Queue as queue import threading import traceback -import zlib -import dill from google.protobuf import wrappers_pb2 from apache_beam.coders import coder_impl @@ -165,37 +163,6 @@ def __iter__(self): yield self._coder.get_impl().decode_from_stream(input_stream, True) -def unpack_and_deserialize_py_fn(function_spec): - """Returns unpacked and deserialized object from function spec proto.""" - return pickler.loads(unpack_function_spec_data(function_spec)) - - -def unpack_function_spec_data(function_spec): - """Returns unpacked data from function spec proto.""" - data = wrappers_pb2.BytesValue() - function_spec.data.Unpack(data) - return data.value - - -# pylint: disable=redefined-builtin -def serialize_and_pack_py_fn(fn, urn, id=None): - """Returns serialized and packed function in a function spec proto.""" - return pack_function_spec_data(pickler.dumps(fn), urn, id) -# pylint: enable=redefined-builtin - - -# pylint: disable=redefined-builtin -def pack_function_spec_data(value, urn, id=None): - """Returns packed data in a function spec proto.""" - data = wrappers_pb2.BytesValue(value=value) - fn_proto = beam_fn_api_pb2.FunctionSpec(urn=urn) - fn_proto.data.Pack(data) - if id: - fn_proto.id = id - return fn_proto -# pylint: enable=redefined-builtin - - def memoize(func): cache = {} missing = object() @@ -286,24 +253,6 @@ def register(self, request, unused_instruction_id=None): self.fns[p_transform.function_spec.id] = p_transform.function_spec return beam_fn_api_pb2.RegisterResponse() - def initial_source_split(self, request, unused_instruction_id=None): - source_spec = self.fns[request.source_reference] - assert source_spec.urn == PYTHON_SOURCE_URN - source_bundle = unpack_and_deserialize_py_fn( - self.fns[request.source_reference]) - splits = source_bundle.source.split(request.desired_bundle_size_bytes, - source_bundle.start_position, - source_bundle.stop_position) - response = beam_fn_api_pb2.InitialSourceSplitResponse() - response.splits.extend([ - 
beam_fn_api_pb2.SourceSplit( - source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN), - relative_size=split.weight, - ) - for split in splits - ]) - return response - def create_execution_tree(self, descriptor): # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index c431bcdf24576..553d5b86cbadc 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -27,10 +27,7 @@ from concurrent import futures import grpc -from apache_beam.io.concat_source_test import RangeSource -from apache_beam.io.iobase import SourceBundle from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker @@ -88,80 +85,6 @@ def test_fn_registration(self): harness.worker.fns, {item.id: item for item in fns + process_bundle_descriptors}) - @unittest.skip("initial splitting not in proto") - def test_source_split(self): - source = RangeSource(0, 100) - expected_splits = list(source.split(30)) - - worker = sdk_harness.SdkWorker( - None, data_plane.GrpcClientDataChannelFactory()) - worker.register( - beam_fn_api_pb2.RegisterRequest( - process_bundle_descriptor=[beam_fn_api_pb2.ProcessBundleDescriptor( - primitive_transform=[beam_fn_api_pb2.PrimitiveTransform( - function_spec=sdk_harness.serialize_and_pack_py_fn( - SourceBundle(1.0, source, None, None), - sdk_harness.PYTHON_SOURCE_URN, - id="src"))])])) - split_response = worker.initial_source_split( - beam_fn_api_pb2.InitialSourceSplitRequest( - desired_bundle_size_bytes=30, - source_reference="src")) - - self.assertEqual( - expected_splits, - [sdk_harness.unpack_and_deserialize_py_fn(s.source) - for s in split_response.splits]) - - self.assertEqual( - [s.weight for s in expected_splits], - [s.relative_size for s in split_response.splits]) - - @unittest.skip("initial splitting not in proto") - def test_source_split_via_instruction(self): - - source = RangeSource(0, 100) - expected_splits = list(source.split(30)) - - test_controller = BeamFnControlServicer([ - beam_fn_api_pb2.InstructionRequest( - instruction_id="register_request", - register=beam_fn_api_pb2.RegisterRequest( - process_bundle_descriptor=[ - beam_fn_api_pb2.ProcessBundleDescriptor( - primitive_transform=[beam_fn_api_pb2.PrimitiveTransform( - function_spec=sdk_harness.serialize_and_pack_py_fn( - SourceBundle(1.0, source, None, None), - sdk_harness.PYTHON_SOURCE_URN, - id="src"))])])), - beam_fn_api_pb2.InstructionRequest( - instruction_id="split_request", - initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest( - desired_bundle_size_bytes=30, - source_reference="src")) - ]) - - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server) - test_port = server.add_insecure_port("[::]:0") - server.start() - - channel = grpc.insecure_channel("localhost:%s" % test_port) - harness = sdk_harness.SdkHarness(channel) - harness.run() - - split_response = test_controller.responses[ - "split_request"].initial_source_split - - self.assertEqual( - expected_splits, - [sdk_harness.unpack_and_deserialize_py_fn(s.source) - for s in split_response.splits]) - - self.assertEqual( - [s.weight for s in expected_splits], - [s.relative_size for s in split_response.splits]) - if __name__ == 
"__main__": logging.getLogger().setLevel(logging.INFO) From 6f12e7d3d6a9fbd0d7bc1a6136542bd503cb0f2b Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Thu, 22 Jun 2017 14:59:36 -0700 Subject: [PATCH 101/200] Bump Dataflow worker to 0622 --- runners/google-cloud-dataflow-java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index d1bce32b9c5b4..fbb0b87fc6204 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -33,7 +33,7 @@ jar - beam-master-20170619 + beam-master-20170622 1 6 From 799173fac4e07dab4547ba21971922336cd72c62 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Wed, 21 Jun 2017 15:21:32 -0700 Subject: [PATCH 102/200] Uses KV in SplittableParDo expansion instead of ElementAndRestriction This is a workaround for the following issue. ElementAndRestriction is in runners-core, which may be shaded by runners (and is shaded by Dataflow runner), hence it should be *both* produced and consumed by workers - but currently it's produced by (shaded) SplittableParDo and consumed by (differently shaded) ProcessFn in the runner's worker code. There are several ways out of this, e.g. moving EAR into the SDK (icky because it's an implementation detail of SplittableParDo), or using a type that's already in the SDK. There may be other more complicated ways too. --- .../construction/ElementAndRestriction.java | 42 ------ .../ElementAndRestrictionCoder.java | 88 ------------ .../core/construction/SplittableParDo.java | 39 +++--- .../ElementAndRestrictionCoderTest.java | 126 ------------------ .../beam/runners/core/ProcessFnRunner.java | 16 +-- .../SplittableParDoViaKeyedWorkItems.java | 49 +++---- .../core/SplittableParDoProcessFnTest.java | 16 +-- ...ttableProcessElementsEvaluatorFactory.java | 37 +++-- .../FlinkStreamingTransformTranslators.java | 19 +-- .../streaming/SplittableDoFnOperator.java | 16 +-- 10 files changed, 77 insertions(+), 371 deletions(-) delete mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestriction.java delete mode 100644 runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoder.java delete mode 100644 runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoderTest.java diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestriction.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestriction.java deleted file mode 100644 index 53a86b1d4f993..0000000000000 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestriction.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core.construction; - -import com.google.auto.value.AutoValue; -import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.transforms.DoFn; - -/** - * A tuple of an element and a restriction applied to processing it with a - * splittable {@link DoFn}. - */ -@Experimental(Experimental.Kind.SPLITTABLE_DO_FN) -@AutoValue -public abstract class ElementAndRestriction { - /** The element to process. */ - public abstract ElementT element(); - - /** The restriction applied to processing the element. */ - public abstract RestrictionT restriction(); - - /** Constructs the {@link ElementAndRestriction}. */ - public static ElementAndRestriction of( - InputT element, RestrictionT restriction) { - return new AutoValue_ElementAndRestriction<>(element, restriction); - } -} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoder.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoder.java deleted file mode 100644 index 5ff0aaead9eb4..0000000000000 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoder.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core.construction; - -import com.google.common.collect.ImmutableList; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.List; -import org.apache.beam.sdk.annotations.Experimental; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.StructuredCoder; - -/** A {@link Coder} for {@link ElementAndRestriction}. */ -@Experimental(Experimental.Kind.SPLITTABLE_DO_FN) -public class ElementAndRestrictionCoder - extends StructuredCoder> { - private final Coder elementCoder; - private final Coder restrictionCoder; - - /** - * Creates an {@link ElementAndRestrictionCoder} from an element coder and a restriction coder. 
- */ - public static ElementAndRestrictionCoder of( - Coder elementCoder, Coder restrictionCoder) { - return new ElementAndRestrictionCoder<>(elementCoder, restrictionCoder); - } - - private ElementAndRestrictionCoder( - Coder elementCoder, Coder restrictionCoder) { - this.elementCoder = elementCoder; - this.restrictionCoder = restrictionCoder; - } - - @Override - public void encode( - ElementAndRestriction value, OutputStream outStream) - throws IOException { - if (value == null) { - throw new CoderException("cannot encode a null ElementAndRestriction"); - } - elementCoder.encode(value.element(), outStream); - restrictionCoder.encode(value.restriction(), outStream); - } - - @Override - public ElementAndRestriction decode(InputStream inStream) - throws IOException { - ElementT key = elementCoder.decode(inStream); - RestrictionT value = restrictionCoder.decode(inStream); - return ElementAndRestriction.of(key, value); - } - - @Override - public List> getCoderArguments() { - return ImmutableList.of(elementCoder, restrictionCoder); - } - - @Override - public void verifyDeterministic() throws NonDeterministicException { - elementCoder.verifyDeterministic(); - restrictionCoder.verifyDeterministic(); - } - - public Coder getElementCoder() { - return elementCoder; - } - - public Coder getRestrictionCoder() { - return restrictionCoder; - } -} diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 665e39d9c3905..5ccafcbc8ea17 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -25,6 +25,7 @@ import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -54,7 +55,7 @@ *

 *   <li>Explode windows, since splitting within each window has to happen independently
 *   <li>Assign a unique key to each element/restriction pair
 *   <li>Process the keyed element/restriction pairs in a runner-specific way with the splittable
- *       {@link DoFn}'s {@link DoFn.ProcessElement} method.
+ *       {@link DoFn}'s {@link DoFn.ProcessElement} method.
 * </ol>
 *
 *

    This transform is intended as a helper for internal use by runners when implementing {@code @@ -93,10 +94,9 @@ public PCollectionTuple expand(PCollection input) { Coder restrictionCoder = DoFnInvokers.invokerFor(fn) .invokeGetRestrictionCoder(input.getPipeline().getCoderRegistry()); - Coder> splitCoder = - ElementAndRestrictionCoder.of(input.getCoder(), restrictionCoder); + Coder> splitCoder = KvCoder.of(input.getCoder(), restrictionCoder); - PCollection>> keyedRestrictions = + PCollection>> keyedRestrictions = input .apply( "Pair with initial restriction", @@ -107,12 +107,10 @@ public PCollectionTuple expand(PCollection input) { // ProcessFn requires all input elements to be in a single window and have a single // element per work item. This must precede the unique keying so each key has a single // associated element. - .apply( - "Explode windows", - ParDo.of(new ExplodeWindowsFn>())) + .apply("Explode windows", ParDo.of(new ExplodeWindowsFn>())) .apply( "Assign unique key", - WithKeys.of(new RandomUniqueKeyFn>())); + WithKeys.of(new RandomUniqueKeyFn>())); return keyedRestrictions.apply( "ProcessKeyedElements", @@ -140,12 +138,11 @@ public void process(ProcessContext c, BoundedWindow window) { /** * Runner-specific primitive {@link PTransform} that invokes the {@link DoFn.ProcessElement} - * method for a splittable {@link DoFn} on each {@link ElementAndRestriction} of the input {@link - * PCollection} of {@link KV KVs} keyed with arbitrary but globally unique keys. + * method for a splittable {@link DoFn} on each {@link KV} of the input {@link PCollection} of + * {@link KV KVs} keyed with arbitrary but globally unique keys. */ public static class ProcessKeyedElements - extends RawPTransform< - PCollection>>, PCollectionTuple> { + extends RawPTransform>>, PCollectionTuple> { private final DoFn fn; private final Coder elementCoder; private final Coder restrictionCoder; @@ -208,9 +205,7 @@ public TupleTagList getAdditionalOutputTags() { } @Override - public PCollectionTuple expand( - PCollection>> - input) { + public PCollectionTuple expand(PCollection>> input) { return createPrimitiveOutputFor( input, fn, mainOutputTag, additionalOutputTags, windowingStrategy); } @@ -257,7 +252,7 @@ public String apply(T input) { * Pairs each input element with its initial restriction using the given splittable {@link DoFn}. */ private static class PairWithRestrictionFn - extends DoFn> { + extends DoFn> { private DoFn fn; private transient DoFnInvoker invoker; @@ -273,7 +268,7 @@ public void setup() { @ProcessElement public void processElement(ProcessContext context) { context.output( - ElementAndRestriction.of( + KV.of( context.element(), invoker.invokeGetInitialRestriction(context.element()))); } @@ -281,9 +276,7 @@ public void processElement(ProcessContext context) { /** Splits the restriction using the given {@link SplitRestriction} method. 
*/ private static class SplitRestrictionFn - extends DoFn< - ElementAndRestriction, - ElementAndRestriction> { + extends DoFn, KV> { private final DoFn splittableFn; private transient DoFnInvoker invoker; @@ -298,14 +291,14 @@ public void setup() { @ProcessElement public void processElement(final ProcessContext c) { - final InputT element = c.element().element(); + final InputT element = c.element().getKey(); invoker.invokeSplitRestriction( element, - c.element().restriction(), + c.element().getValue(), new OutputReceiver() { @Override public void output(RestrictionT part) { - c.output(ElementAndRestriction.of(element, part)); + c.output(KV.of(element, part)); } }); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoderTest.java deleted file mode 100644 index 051cbaa0008a8..0000000000000 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ElementAndRestrictionCoderTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core.construction; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.coders.BigEndianLongCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.testing.CoderProperties; -import org.apache.beam.sdk.util.CoderUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameter; - -/** - * Tests for {@link ElementAndRestrictionCoder}. 
- */ -@RunWith(Parameterized.class) -public class ElementAndRestrictionCoderTest { - private static class CoderAndData { - Coder coder; - List data; - } - - private static class AnyCoderAndData { - private CoderAndData coderAndData; - } - - private static AnyCoderAndData coderAndData(Coder coder, List data) { - CoderAndData coderAndData = new CoderAndData<>(); - coderAndData.coder = coder; - coderAndData.data = data; - AnyCoderAndData res = new AnyCoderAndData(); - res.coderAndData = coderAndData; - return res; - } - - private static final List TEST_DATA = - Arrays.asList( - coderAndData( - VarIntCoder.of(), Arrays.asList(-1, 0, 1, 13, Integer.MAX_VALUE, Integer.MIN_VALUE)), - coderAndData( - BigEndianLongCoder.of(), - Arrays.asList(-1L, 0L, 1L, 13L, Long.MAX_VALUE, Long.MIN_VALUE)), - coderAndData(StringUtf8Coder.of(), Arrays.asList("", "hello", "goodbye", "1")), - coderAndData( - ElementAndRestrictionCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), - Arrays.asList( - ElementAndRestriction.of("", -1), - ElementAndRestriction.of("hello", 0), - ElementAndRestriction.of("goodbye", Integer.MAX_VALUE))), - coderAndData( - ListCoder.of(VarLongCoder.of()), - Arrays.asList(Arrays.asList(1L, 2L, 3L), Collections.emptyList()))); - - @Parameterized.Parameters(name = "{index}: keyCoder={0} key={1} valueCoder={2} value={3}") - public static Collection data() { - List parameters = new ArrayList<>(); - for (AnyCoderAndData keyCoderAndData : TEST_DATA) { - Coder keyCoder = keyCoderAndData.coderAndData.coder; - for (Object key : keyCoderAndData.coderAndData.data) { - for (AnyCoderAndData valueCoderAndData : TEST_DATA) { - Coder valueCoder = valueCoderAndData.coderAndData.coder; - for (Object value : valueCoderAndData.coderAndData.data) { - parameters.add(new Object[] {keyCoder, key, valueCoder, value}); - } - } - } - } - return parameters; - } - - @Parameter(0) - public Coder keyCoder; - @Parameter(1) - public K key; - @Parameter(2) - public Coder valueCoder; - @Parameter(3) - public V value; - - @Test - @SuppressWarnings("rawtypes") - public void testDecodeEncodeEqual() throws Exception { - CoderProperties.coderDecodeEncodeEqual( - ElementAndRestrictionCoder.of(keyCoder, valueCoder), - ElementAndRestriction.of(key, value)); - } - - @Rule public ExpectedException thrown = ExpectedException.none(); - - @Test - public void encodeNullThrowsCoderException() throws Exception { - thrown.expect(CoderException.class); - thrown.expectMessage("cannot encode a null ElementAndRestriction"); - - CoderUtils.encodeToBase64( - ElementAndRestrictionCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), null); - } -} diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java index 31e86bdd0dc55..88275d69726a9 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ProcessFnRunner.java @@ -24,11 +24,11 @@ import java.util.Collections; import org.apache.beam.runners.core.StateNamespaces.WindowNamespace; import org.apache.beam.runners.core.TimerInternals.TimerData; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import 
org.apache.beam.sdk.values.PCollectionView; import org.joda.time.Instant; @@ -38,16 +38,13 @@ */ public class ProcessFnRunner implements PushbackSideInputDoFnRunner< - KeyedWorkItem>, OutputT> { - private final DoFnRunner< - KeyedWorkItem>, OutputT> - underlying; + KeyedWorkItem>, OutputT> { + private final DoFnRunner>, OutputT> underlying; private final Collection> views; private final ReadyCheckingSideInputReader sideInputReader; ProcessFnRunner( - DoFnRunner>, OutputT> - underlying, + DoFnRunner>, OutputT> underlying, Collection> views, ReadyCheckingSideInputReader sideInputReader) { this.underlying = underlying; @@ -61,10 +58,9 @@ public void startBundle() { } @Override - public Iterable>>> + public Iterable>>> processElementInReadyWindows( - WindowedValue>> - windowedKWI) { + WindowedValue>> windowedKWI) { checkTrivialOuterWindows(windowedKWI); BoundedWindow window = getUnderlyingWindow(windowedKWI.getValue()); if (!isReady(window)) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index c4b086a410a7b..09f3b157f7be3 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -21,7 +21,6 @@ import com.google.common.collect.Iterables; import java.util.List; import java.util.Map; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; import org.apache.beam.runners.core.construction.ReplacementOutputs; @@ -86,15 +85,15 @@ public String getUrn() { /** Overrides a {@link ProcessKeyedElements} into {@link SplittableProcessViaKeyedWorkItems}. */ public static class OverrideFactory implements PTransformOverrideFactory< - PCollection>>, PCollectionTuple, - ProcessKeyedElements> { + PCollection>>, PCollectionTuple, + ProcessKeyedElements> { @Override public PTransformReplacement< - PCollection>>, PCollectionTuple> + PCollection>>, PCollectionTuple> getReplacementTransform( AppliedPTransform< - PCollection>>, - PCollectionTuple, ProcessKeyedElements> + PCollection>>, PCollectionTuple, + ProcessKeyedElements> transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), @@ -113,8 +112,7 @@ public Map mapOutputs( * method for a splittable {@link DoFn}. 
*/ public static class SplittableProcessViaKeyedWorkItems - extends PTransform< - PCollection>>, PCollectionTuple> { + extends PTransform>>, PCollectionTuple> { private final ProcessKeyedElements original; public SplittableProcessViaKeyedWorkItems( @@ -123,15 +121,13 @@ public SplittableProcessViaKeyedWorkItems( } @Override - public PCollectionTuple expand( - PCollection>> input) { + public PCollectionTuple expand(PCollection>> input) { return input - .apply(new GBKIntoKeyedWorkItems>()) + .apply(new GBKIntoKeyedWorkItems>()) .setCoder( KeyedWorkItemCoder.of( StringUtf8Coder.of(), - ((KvCoder>) input.getCoder()) - .getValueCoder(), + ((KvCoder>) input.getCoder()).getValueCoder(), input.getWindowingStrategy().getWindowFn().windowCoder())) .apply(new ProcessElements<>(original)); } @@ -141,8 +137,7 @@ public PCollectionTuple expand( public static class ProcessElements< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> extends PTransform< - PCollection>>, - PCollectionTuple> { + PCollection>>, PCollectionTuple> { private final ProcessKeyedElements original; public ProcessElements(ProcessKeyedElements original) { @@ -176,7 +171,7 @@ public TupleTagList getAdditionalOutputTags() { @Override public PCollectionTuple expand( - PCollection>> input) { + PCollection>> input) { return ProcessKeyedElements.createPrimitiveOutputFor( input, original.getFn(), @@ -201,7 +196,7 @@ public PCollectionTuple expand( @VisibleForTesting public static class ProcessFn< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> - extends DoFn>, OutputT> { + extends DoFn>, OutputT> { /** * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is * acquired during the first {@link DoFn.ProcessElement} call for each element and restriction, @@ -321,7 +316,7 @@ public void processElement(final ProcessContext c) { boolean isSeedCall = (timer == null); StateNamespace stateNamespace; if (isSeedCall) { - WindowedValue> windowedValue = + WindowedValue> windowedValue = Iterables.getOnlyElement(c.element().elementsIterable()); BoundedWindow window = Iterables.getOnlyElement(windowedValue.getWindows()); stateNamespace = @@ -337,27 +332,25 @@ public void processElement(final ProcessContext c) { stateInternals.state(stateNamespace, restrictionTag); WatermarkHoldState holdState = stateInternals.state(stateNamespace, watermarkHoldTag); - ElementAndRestriction, RestrictionT> elementAndRestriction; + KV, RestrictionT> elementAndRestriction; if (isSeedCall) { - WindowedValue> windowedValue = + WindowedValue> windowedValue = Iterables.getOnlyElement(c.element().elementsIterable()); - WindowedValue element = windowedValue.withValue(windowedValue.getValue().element()); + WindowedValue element = windowedValue.withValue(windowedValue.getValue().getKey()); elementState.write(element); - elementAndRestriction = - ElementAndRestriction.of(element, windowedValue.getValue().restriction()); + elementAndRestriction = KV.of(element, windowedValue.getValue().getValue()); } else { // This is not the first ProcessElement call for this element/restriction - rather, // this is a timer firing, so we need to fetch the element and restriction from state. 
elementState.readLater(); restrictionState.readLater(); - elementAndRestriction = - ElementAndRestriction.of(elementState.read(), restrictionState.read()); + elementAndRestriction = KV.of(elementState.read(), restrictionState.read()); } - final TrackerT tracker = invoker.invokeNewTracker(elementAndRestriction.restriction()); + final TrackerT tracker = invoker.invokeNewTracker(elementAndRestriction.getValue()); SplittableProcessElementInvoker.Result result = processElementInvoker.invokeProcessElement( - invoker, elementAndRestriction.element(), tracker); + invoker, elementAndRestriction.getKey(), tracker); // Save state for resuming. if (result.getResidualRestriction() == null) { @@ -370,7 +363,7 @@ public void processElement(final ProcessContext c) { restrictionState.write(result.getResidualRestriction()); Instant futureOutputWatermark = result.getFutureOutputWatermark(); if (futureOutputWatermark == null) { - futureOutputWatermark = elementAndRestriction.element().getTimestamp(); + futureOutputWatermark = elementAndRestriction.getKey().getTimestamp(); } holdState.add(futureOutputWatermark); // Set a timer to continue processing this element. diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java index d2424318eea56..9543de8c61a00 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java @@ -35,7 +35,6 @@ import java.util.NoSuchElementException; import java.util.concurrent.Executors; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessFn; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; @@ -53,6 +52,7 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; @@ -111,9 +111,7 @@ public void checkDone() {} private static class ProcessFnTester< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> implements AutoCloseable { - private final DoFnTester< - KeyedWorkItem>, OutputT> - tester; + private final DoFnTester>, OutputT> tester; private Instant currentProcessingTime; private InMemoryTimerInternals timerInternals; @@ -194,14 +192,13 @@ public void close() throws Exception { void startElement(InputT element, RestrictionT restriction) throws Exception { startElement( WindowedValue.of( - ElementAndRestriction.of(element, restriction), + KV.of(element, restriction), currentProcessingTime, GlobalWindow.INSTANCE, PaneInfo.ON_TIME_AND_ONLY_FIRING)); } - void startElement(WindowedValue> windowedValue) - throws Exception { + void startElement(WindowedValue> windowedValue) throws Exception { tester.processElement( KeyedWorkItems.elementsWorkItem("key", Collections.singletonList(windowedValue))); } @@ -223,8 +220,7 @@ boolean advanceProcessingTimeBy(Duration duration) throws Exception { return false; } tester.processElement( - KeyedWorkItems.>timersWorkItem( - "key", timers)); + 
KeyedWorkItems.>timersWorkItem("key", timers)); return true; } @@ -309,7 +305,7 @@ public void testTrivialProcessFnPropagatesOutputWindowAndTimestamp() throws Exce MAX_BUNDLE_DURATION); tester.startElement( WindowedValue.of( - ElementAndRestriction.of(42, new SomeRestriction()), + KV.of(42, new SomeRestriction()), base, Collections.singletonList(w), PaneInfo.ON_TIME_AND_ONLY_FIRING)); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index eccc83a031cb2..e6b51b79ce7f3 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -35,7 +35,6 @@ import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternalsFactory; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; @@ -43,6 +42,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; @@ -54,8 +54,7 @@ class SplittableProcessElementsEvaluatorFactory< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> implements TransformEvaluatorFactory { - private final ParDoEvaluatorFactory< - KeyedWorkItem>, OutputT> + private final ParDoEvaluatorFactory>, OutputT> delegateFactory; private final EvaluationContext evaluationContext; @@ -84,14 +83,13 @@ public void cleanup() throws Exception { } @SuppressWarnings({"unchecked", "rawtypes"}) - private TransformEvaluator>> - createEvaluator( - AppliedPTransform< - PCollection>>, - PCollectionTuple, ProcessElements> - application, - CommittedBundle inputBundle) - throws Exception { + private TransformEvaluator>> createEvaluator( + AppliedPTransform< + PCollection>>, PCollectionTuple, + ProcessElements> + application, + CommittedBundle inputBundle) + throws Exception { final ProcessElements transform = application.getTransform(); @@ -101,9 +99,7 @@ public void cleanup() throws Exception { DoFnLifecycleManager fnManager = DoFnLifecycleManager.of(processFn); processFn = ((ProcessFn) - fnManager - .>, OutputT> - get()); + fnManager.>, OutputT>get()); String stepName = evaluationContext.getStepName(application); final DirectExecutionContext.DirectStepContext stepContext = @@ -111,12 +107,12 @@ public void cleanup() throws Exception { .getExecutionContext(application, inputBundle.getKey()) .getStepContext(stepName); - final ParDoEvaluator>> + final ParDoEvaluator>> parDoEvaluator = delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), - (PCollection>>) + (PCollection>>) inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), @@ -189,17 +185,16 @@ public void outputWindowedValue( } private static - ParDoEvaluator.DoFnRunnerFactory< - KeyedWorkItem>, OutputT> + ParDoEvaluator.DoFnRunnerFactory>, OutputT> 
processFnRunnerFactory() { return new ParDoEvaluator.DoFnRunnerFactory< - KeyedWorkItem>, OutputT>() { + KeyedWorkItem>, OutputT>() { @Override public PushbackSideInputDoFnRunner< - KeyedWorkItem>, OutputT> + KeyedWorkItem>, OutputT> createRunner( PipelineOptions options, - DoFn>, OutputT> fn, + DoFn>, OutputT> fn, List> sideInputs, ReadyCheckingSideInputReader sideInputReader, OutputManager outputManager, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index fef32de77edfa..3d7e81f0584f9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -29,7 +29,6 @@ import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; @@ -548,14 +547,11 @@ public void translateNode( transform.getAdditionalOutputTags().getAll(), context, new ParDoTranslationHelper.DoFnOperatorFactory< - KeyedWorkItem>, OutputT>() { + KeyedWorkItem>, OutputT>() { @Override - public DoFnOperator< - KeyedWorkItem>, - OutputT> createDoFnOperator( - DoFn< - KeyedWorkItem>, - OutputT> doFn, + public DoFnOperator>, OutputT> + createDoFnOperator( + DoFn>, OutputT> doFn, String stepName, List> sideInputs, TupleTag mainOutputTag, @@ -563,11 +559,8 @@ OutputT> createDoFnOperator( FlinkStreamingTranslationContext context, WindowingStrategy windowingStrategy, Map, OutputTag>> tagsToOutputTags, - Coder< - WindowedValue< - KeyedWorkItem< - String, - ElementAndRestriction>>> inputCoder, + Coder>>> + inputCoder, Coder keyCoder, Map> transformedSideInputs) { return new SplittableDoFnOperator<>( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index 5d08eba96a28e..2f095d481694e 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -35,7 +35,6 @@ import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternalsFactory; -import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; @@ -43,6 +42,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; @@ -55,18 +55,15 @@ * the {@code 
@ProcessElement} method of a splittable {@link DoFn}. */ public class SplittableDoFnOperator< - InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> - extends DoFnOperator< - KeyedWorkItem>, OutputT> { + InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> + extends DoFnOperator>, OutputT> { private transient ScheduledExecutorService executorService; public SplittableDoFnOperator( - DoFn>, OutputT> doFn, + DoFn>, OutputT> doFn, String stepName, - Coder< - WindowedValue< - KeyedWorkItem>>> inputCoder, + Coder>>> inputCoder, TupleTag mainOutputTag, List> additionalOutputTags, OutputManagerFactory outputManagerFactory, @@ -87,7 +84,6 @@ public SplittableDoFnOperator( sideInputs, options, keyCoder); - } @Override @@ -151,7 +147,7 @@ public void outputWindowedValue( @Override public void fireTimer(InternalTimer timer) { doFnRunner.processElement(WindowedValue.valueInGlobalWindow( - KeyedWorkItems.>timersWorkItem( + KeyedWorkItems.>timersWorkItem( (String) stateInternals.getKey(), Collections.singletonList(timer.getNamespace())))); } From cab4d8969e7f95b0ece59838ad2d578e75d38823 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 21 Jun 2017 20:25:31 -0700 Subject: [PATCH 103/200] DataflowRunner: Reject merging windowing for stateful ParDo --- .../dataflow/BatchStatefulParDoOverrides.java | 2 + .../dataflow/DataflowPipelineTranslator.java | 5 ++- .../beam/runners/dataflow/DataflowRunner.java | 10 +++++ .../runners/dataflow/DataflowRunnerTest.java | 38 +++++++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index 41202db0e4690..7309f6171e6f0 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -146,6 +146,7 @@ public PCollection expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); DataflowRunner.verifyStateSupported(fn); + DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< PCollection>>>>>, @@ -171,6 +172,7 @@ public PCollectionTuple expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); DataflowRunner.verifyStateSupported(fn); + DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< PCollection>>>>>, diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 6d3054407b216..28fd1bb1af02c 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -972,7 +972,10 @@ private static void translateFn( fn)); } - DataflowRunner.verifyStateSupported(fn); + if (signature.usesState() || signature.usesTimers()) { + DataflowRunner.verifyStateSupported(fn); + DataflowRunner.verifyStateSupportForWindowingStrategy(windowingStrategy); + } stepContext.addInput(PropertyNames.USER_FN, 
fn.getClass().getName()); stepContext.addInput( diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 4d7f6acfef3bb..5d9f0f32aca4a 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -1542,4 +1542,14 @@ static void verifyStateSupported(DoFn fn) { } } } + + static void verifyStateSupportForWindowingStrategy(WindowingStrategy strategy) { + // https://issues.apache.org/jira/browse/BEAM-2507 + if (!strategy.getWindowFn().isNonMerging()) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support state or timers with merging windows", + DataflowRunner.class.getSimpleName())); + } + } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index f57c0ee5ad3ca..bc1a04247c223 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -93,12 +93,15 @@ import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.GcsUtil; import org.apache.beam.sdk.util.ReleaseInfo; import org.apache.beam.sdk.util.gcsfs.GcsPath; @@ -112,6 +115,7 @@ import org.hamcrest.Description; import org.hamcrest.Matchers; import org.hamcrest.TypeSafeMatcher; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; @@ -127,6 +131,8 @@ /** * Tests for the {@link DataflowRunner}. + * + *

    Implements {@link Serializable} because it is caught in closures. */ @RunWith(JUnit4.class) public class DataflowRunnerTest implements Serializable { @@ -1222,6 +1228,38 @@ public void testStreamingWriteWithNoShardingReturnsNewTransformMaxWorkersUnset() testStreamingWriteOverride(options, StreamingShardedWriteFactory.DEFAULT_NUM_SHARDS); } + private void verifyMergingStatefulParDoRejected(PipelineOptions options) throws Exception { + Pipeline p = Pipeline.create(options); + + p.apply(Create.of(KV.of(13, 42))) + .apply(Window.>into(Sessions.withGapDuration(Duration.millis(1)))) + .apply(ParDo.of(new DoFn, Void>() { + @StateId("fizzle") + private final StateSpec> voidState = StateSpecs.value(); + + @ProcessElement + public void process() {} + })); + + thrown.expectMessage("merging"); + thrown.expect(UnsupportedOperationException.class); + p.run(); + } + + @Test + public void testMergingStatefulRejectedInStreaming() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(true); + verifyMergingStatefulParDoRejected(options); + } + + @Test + public void testMergingStatefulRejectedInBatch() throws Exception { + PipelineOptions options = buildPipelineOptions(); + options.as(StreamingOptions.class).setStreaming(false); + verifyMergingStatefulParDoRejected(options); + } + private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) { TestPipeline p = TestPipeline.fromOptions(options); From 649994b353afe28c917969609c7a1a47a4f39aaf Mon Sep 17 00:00:00 2001 From: Rune Fevang Date: Thu, 15 Jun 2017 13:51:12 +0200 Subject: [PATCH 104/200] Allow output from FinishBundle in DoFnTester --- .../beam/sdk/transforms/DoFnTester.java | 16 ++-------- .../beam/sdk/transforms/DoFnTesterTest.java | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 8a03f3c0064a2..4da9a8096f44f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -546,11 +546,6 @@ private TestFinishBundleContext() { fn.super(); } - private void throwUnsupportedOutputFromBundleMethods() { - throw new UnsupportedOperationException( - "DoFnTester doesn't support output from bundle methods"); - } - @Override public PipelineOptions getPipelineOptions() { return options; @@ -559,12 +554,13 @@ public PipelineOptions getPipelineOptions() { @Override public void output( OutputT output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutputFromBundleMethods(); + output(mainOutputTag, output, timestamp, window); } @Override public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - throwUnsupportedOutputFromBundleMethods(); + getMutableOutput(tag) + .add(ValueInSingleWindow.of(output, timestamp, window, PaneInfo.NO_FIRING)); } } @@ -642,12 +638,6 @@ public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp getMutableOutput(tag) .add(ValueInSingleWindow.of(output, timestamp, element.getWindow(), element.getPane())); } - - private void throwUnsupportedOutputFromBundleMethods() { - throw new UnsupportedOperationException( - "DoFnTester doesn't support output from bundle methods"); - } - } @Override diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java index 1bb71bbf1a180..5cb9e18ca9b13 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java @@ -360,6 +360,38 @@ public void processElement(ProcessContext c, BoundedWindow window) { } } + @Test + public void testSupportsFinishBundleOutput() throws Exception { + for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) { + try (DoFnTester tester = DoFnTester.of(new BundleCounterDoFn())) { + tester.setCloningBehavior(cloning); + + assertThat(tester.processBundle(1, 2, 3, 4), contains(4)); + assertThat(tester.processBundle(5, 6, 7), contains(3)); + assertThat(tester.processBundle(8, 9), contains(2)); + } + } + } + + private static class BundleCounterDoFn extends DoFn { + private int elements; + + @StartBundle + public void startBundle() { + elements = 0; + } + + @ProcessElement + public void processElement(ProcessContext c) { + elements++; + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + c.output(elements, Instant.now(), GlobalWindow.INSTANCE); + } + } + private static class SideInputDoFn extends DoFn { private final PCollectionView value; From 8dcda6e40355af13f4d92fcd44aae4539a225a4a Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Thu, 22 Jun 2017 17:08:20 -0700 Subject: [PATCH 105/200] [BEAM-2497] Fix the reading of concat gzip files --- .../examples/snippets/snippets_test.py | 16 ++++++++++++++++ sdks/python/apache_beam/io/filesystem.py | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index 9183d0dfea190..31f71b3bbb0da 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -589,6 +589,22 @@ def test_model_textio_compressed(self): snippets.model_textio_compressed( {'read': gzip_file_name}, ['aa', 'bb', 'cc']) + def test_model_textio_gzip_concatenated(self): + temp_path_1 = self.create_temp_file('a\nb\nc\n') + temp_path_2 = self.create_temp_file('p\nq\nr\n') + temp_path_3 = self.create_temp_file('x\ny\nz') + gzip_file_name = temp_path_1 + '.gz' + with open(temp_path_1) as src, gzip.open(gzip_file_name, 'wb') as dst: + dst.writelines(src) + with open(temp_path_2) as src, gzip.open(gzip_file_name, 'ab') as dst: + dst.writelines(src) + with open(temp_path_3) as src, gzip.open(gzip_file_name, 'ab') as dst: + dst.writelines(src) + # Add the temporary gzip file to be cleaned up as well. + self.temp_files.append(gzip_file_name) + snippets.model_textio_compressed( + {'read': gzip_file_name}, ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']) + @unittest.skipIf(datastore_pb2 is None, 'GCP dependencies are not installed') def test_model_datastoreio(self): # We cannot test datastoreio functionality in unit tests therefore we limit diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index f5530262b4ca7..1f65d0a3a9edc 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -201,6 +201,14 @@ def _fetch_to_internal_buffer(self, num_bytes): assert False, 'Possible file corruption.' except EOFError: pass # All is as expected! 
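The branch added below relies on how zlib surfaces concatenated gzip members. As a minimal standalone sketch (illustrative only, not part of the patch; the wbits constant and variable names are assumptions rather than values taken from filesystem.py): a single decompressor object stops at the end of the first gzip member and leaves the remaining compressed bytes in unused_data, so a reader has to build a fresh decompressor and feed those bytes back in to reach the later members.

import gzip
import io
import zlib

# Build a concatenated gzip stream with two members.
buf = io.BytesIO()
for chunk in (b'a\nb\nc\n', b'x\ny\nz\n'):
  with gzip.GzipFile(fileobj=buf, mode='wb') as f:
    f.write(chunk)
data = buf.getvalue()

# Decompress member by member; 16 + MAX_WBITS asks zlib for gzip framing.
parts = []
decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
while data:
  parts.append(decompressor.decompress(data))
  data = decompressor.unused_data
  if data:
    # A second member follows: start over with a new decompressor.
    decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
print(b''.join(parts))  # b'a\nb\nc\nx\ny\nz\n'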
+ elif self._compression_type == CompressionTypes.GZIP: + # If Gzip file check if there is unused data generated by gzip concat + if self._decompressor.unused_data != '': + buf = self._decompressor.unused_data + self._decompressor = zlib.decompressobj(self._gzip_mask) + decompressed = self._decompressor.decompress(buf) + self._read_buffer.write(decompressed) + continue else: self._read_buffer.write(self._decompressor.flush()) From f291713b28e3ba0246a8c0a710c71506cd0a0f91 Mon Sep 17 00:00:00 2001 From: Etienne Chauchot Date: Wed, 21 Jun 2017 10:39:39 +0200 Subject: [PATCH 106/200] [BEAM-2489] Use dynamic ES port in HIFIOWithElasticTest --- .../io/hadoop/inputformat/HIFIOWithElasticTest.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sdks/java/io/hadoop/jdk1.8-tests/src/test/java/org/apache/beam/sdk/io/hadoop/inputformat/HIFIOWithElasticTest.java b/sdks/java/io/hadoop/jdk1.8-tests/src/test/java/org/apache/beam/sdk/io/hadoop/inputformat/HIFIOWithElasticTest.java index 8745521a30643..3f866a4845312 100644 --- a/sdks/java/io/hadoop/jdk1.8-tests/src/test/java/org/apache/beam/sdk/io/hadoop/inputformat/HIFIOWithElasticTest.java +++ b/sdks/java/io/hadoop/jdk1.8-tests/src/test/java/org/apache/beam/sdk/io/hadoop/inputformat/HIFIOWithElasticTest.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.io.Serializable; +import java.net.ServerSocket; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -76,7 +77,7 @@ public class HIFIOWithElasticTest implements Serializable { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(HIFIOWithElasticTest.class); private static final String ELASTIC_IN_MEM_HOSTNAME = "127.0.0.1"; - private static final String ELASTIC_IN_MEM_PORT = "9200"; + private static String elasticInMemPort = "9200"; private static final String ELASTIC_INTERNAL_VERSION = "5.x"; private static final String TRUE = "true"; private static final String ELASTIC_INDEX_NAME = "beamdb"; @@ -94,6 +95,10 @@ public class HIFIOWithElasticTest implements Serializable { @BeforeClass public static void startServer() throws NodeValidationException, InterruptedException, IOException { + ServerSocket serverSocket = new ServerSocket(0); + int port = serverSocket.getLocalPort(); + serverSocket.close(); + elasticInMemPort = String.valueOf(port); ElasticEmbeddedServer.startElasticEmbeddedServer(); } @@ -173,7 +178,7 @@ public void testHifIOWithElasticQuery() { public Configuration getConfiguration() { Configuration conf = new Configuration(); conf.set(ConfigurationOptions.ES_NODES, ELASTIC_IN_MEM_HOSTNAME); - conf.set(ConfigurationOptions.ES_PORT, String.format("%s", ELASTIC_IN_MEM_PORT)); + conf.set(ConfigurationOptions.ES_PORT, String.format("%s", elasticInMemPort)); conf.set(ConfigurationOptions.ES_RESOURCE, ELASTIC_RESOURCE); conf.set("es.internal.es.version", ELASTIC_INTERNAL_VERSION); conf.set(ConfigurationOptions.ES_NODES_DISCOVERY, TRUE); @@ -209,7 +214,7 @@ public static void startElasticEmbeddedServer() Settings settings = Settings.builder() .put("node.data", TRUE) .put("network.host", ELASTIC_IN_MEM_HOSTNAME) - .put("http.port", ELASTIC_IN_MEM_PORT) + .put("http.port", elasticInMemPort) .put("path.data", elasticTempFolder.getRoot().getPath()) .put("path.home", elasticTempFolder.getRoot().getPath()) .put("transport.type", "local") From fd8f15f1ac761425dc791a455b042a8846081f48 Mon Sep 17 00:00:00 2001 From: Mark Liu Date: Thu, 22 Jun 2017 14:04:00 -0700 
Subject: [PATCH 107/200] [BEAM-2745] Add Jenkins Suite for Python Performance Test --- .../jenkins/common_job_properties.groovy | 4 +- .../job_beam_PerformanceTests_Python.groovy | 58 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 .test-infra/jenkins/job_beam_PerformanceTests_Python.groovy diff --git a/.test-infra/jenkins/common_job_properties.groovy b/.test-infra/jenkins/common_job_properties.groovy index 6d4d68b7a3402..0e047eac70d56 100644 --- a/.test-infra/jenkins/common_job_properties.groovy +++ b/.test-infra/jenkins/common_job_properties.groovy @@ -264,8 +264,10 @@ class common_job_properties { shell('rm -rf PerfKitBenchmarker') // Clone appropriate perfkit branch shell('git clone https://github.com/GoogleCloudPlatform/PerfKitBenchmarker.git') - // Install job requirements. + // Install Perfkit benchmark requirements. shell('pip install --user -r PerfKitBenchmarker/requirements.txt') + // Install job requirements for Python SDK. + shell('pip install --user -e sdks/python/[gcp,test]') // Launch performance test. shell("python PerfKitBenchmarker/pkb.py $pkbArgs") } diff --git a/.test-infra/jenkins/job_beam_PerformanceTests_Python.groovy b/.test-infra/jenkins/job_beam_PerformanceTests_Python.groovy new file mode 100644 index 0000000000000..6a71bdaa51dc2 --- /dev/null +++ b/.test-infra/jenkins/job_beam_PerformanceTests_Python.groovy @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import common_job_properties + +// This job runs the Beam Python performance tests on PerfKit Benchmarker. +job('beam_PerformanceTests_Python'){ + // Set default Beam job properties. + common_job_properties.setTopLevelMainJobProperties(delegate) + + // Run job in postcommit every 6 hours, don't trigger every push. + common_job_properties.setPostCommit( + delegate, + '0 */6 * * *', + false, + 'commits@beam.apache.org') + + // Allows triggering this build against pull requests. 
+ common_job_properties.enablePhraseTriggeringFromPullRequest( + delegate, + 'Python SDK Performance Test', + 'Run Python Performance Test') + + def pipelineArgs = [ + project: 'apache-beam-testing', + staging_location: 'gs://temp-storage-for-end-to-end-tests/staging-it', + temp_location: 'gs://temp-storage-for-end-to-end-tests/temp-it', + output: 'gs://temp-storage-for-end-to-end-tests/py-it-cloud/output' + ] + def pipelineArgList = [] + pipelineArgs.each({ + key, value -> pipelineArgList.add("--$key=$value") + }) + def pipelineArgsJoined = pipelineArgList.join(',') + + def argMap = [ + beam_sdk : 'python', + benchmarks: 'beam_integration_benchmark', + beam_it_args: pipelineArgsJoined + ] + + common_job_properties.buildPerformanceTest(delegate, argMap) +} From 32095487e56b63b5c1aa690bb6e098375cb108d5 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Fri, 23 Jun 2017 11:50:12 -0700 Subject: [PATCH 108/200] Fix python fn API data plane remote grpc port access --- sdks/python/apache_beam/runners/worker/data_plane.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index bc981a8d30edd..26f65ee7d0348 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -246,8 +246,8 @@ class DataChannelFactory(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def create_data_channel(self, function_spec): - """Returns a ``DataChannel`` from the given function_spec.""" + def create_data_channel(self, remote_grpc_port): + """Returns a ``DataChannel`` from the given RemoteGrpcPort.""" raise NotImplementedError(type(self)) @abc.abstractmethod @@ -265,9 +265,7 @@ class GrpcClientDataChannelFactory(DataChannelFactory): def __init__(self): self._data_channel_cache = {} - def create_data_channel(self, function_spec): - remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort() - function_spec.data.Unpack(remote_grpc_port) + def create_data_channel(self, remote_grpc_port): url = remote_grpc_port.api_service_descriptor.url if url not in self._data_channel_cache: logging.info('Creating channel for %s', url) @@ -289,7 +287,7 @@ class InMemoryDataChannelFactory(DataChannelFactory): def __init__(self, in_memory_data_channel): self._in_memory_data_channel = in_memory_data_channel - def create_data_channel(self, unused_function_spec): + def create_data_channel(self, unused_remote_grpc_port): return self._in_memory_data_channel def close(self): From 903da41ac5395e76c44ef8ae1c8a695569e23abb Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 23 Jun 2017 15:01:42 -0700 Subject: [PATCH 109/200] Avoid pickling the entire pipeline per-transform. --- sdks/python/apache_beam/pipeline.py | 7 +++++++ sdks/python/apache_beam/pipeline_test.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index d84a2b7b59cce..724c87d023f14 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -466,6 +466,13 @@ def apply(self, transform, pvalueish=None, label=None): self.transforms_stack.pop() return pvalueish_result + def __reduce__(self): + # Some transforms contain a reference to their enclosing pipeline, + # which in turn reference all other transforms (resulting in quadratic + # time/space to pickle each transform individually). 
As we don't + # require pickled pipelines to be executable, break the chain here. + return str, ('Pickled pipeline stub.',) + def _verify_runner_api_compatible(self): class Visitor(PipelineVisitor): # pylint: disable=used-before-assignment ok = True # Really a nonlocal. diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index f9b894f72eb72..aad01435fd9ea 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -480,6 +480,24 @@ def test_simple(self): p2 = Pipeline.from_runner_api(proto, p.runner, p._options) p2.run() + def test_pickling(self): + class MyPTransform(beam.PTransform): + pickle_count = [0] + + def expand(self, p): + self.p = p + return p | beam.Create([None]) + + def __reduce__(self): + self.pickle_count[0] += 1 + return str, () + + p = beam.Pipeline() + for k in range(20): + p | 'Iter%s' % k >> MyPTransform() # pylint: disable=expression-not-assigned + p.to_runner_api() + self.assertEqual(MyPTransform.pickle_count[0], 20) + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG) From e45f522d6e945899c20259ebf8faca105c2e552e Mon Sep 17 00:00:00 2001 From: Valentyn Tymofieiev Date: Fri, 23 Jun 2017 16:44:49 -0700 Subject: [PATCH 110/200] Fix a typo in function args --- sdks/python/apache_beam/examples/streaming_wordcount.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/examples/streaming_wordcount.py b/sdks/python/apache_beam/examples/streaming_wordcount.py index f2b179aa2438d..4c29f2b46b302 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount.py @@ -33,7 +33,7 @@ def split_fn(lines): import re - return re.findall(r'[A-Za-z\']+', x) + return re.findall(r'[A-Za-z\']+', lines) def run(argv=None): From 926f949580c3a21df72a8836feda1f6b947850ec Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Mon, 26 Jun 2017 13:00:14 -0700 Subject: [PATCH 111/200] Remove old deprecated PubSub code --- sdks/python/apache_beam/io/gcp/pubsub.py | 71 +------------------ .../runners/dataflow/internal/dependency.py | 2 +- 2 files changed, 2 insertions(+), 71 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 6dc15288276da..fabe29612a8b6 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -33,8 +33,7 @@ from apache_beam.transforms.display import DisplayDataItem -__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub', - 'PubSubSource', 'PubSubSink'] +__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub'] class ReadStringsFromPubSub(PTransform): @@ -160,71 +159,3 @@ def display_data(self): def writer(self): raise NotImplementedError( 'PubSubPayloadSink is not supported in local execution.') - - -class PubSubSource(dataflow_io.NativeSource): - """Deprecated: do not use. - - Source for reading from a given Cloud Pub/Sub topic. - - Attributes: - topic: Cloud Pub/Sub topic in the form "/topics//". - subscription: Optional existing Cloud Pub/Sub subscription to use in the - form "projects//subscriptions/". - id_label: The attribute on incoming Pub/Sub messages to use as a unique - record identifier. When specified, the value of this attribute (which can - be any string that uniquely identifies the record) will be used for - deduplication of messages. If not provided, Dataflow cannot guarantee - that no duplicate data will be delivered on the Pub/Sub stream. 
In this - case, deduplication of the stream will be strictly best effort. - coder: The Coder to use for decoding incoming Pub/Sub messages. - """ - - def __init__(self, topic, subscription=None, id_label=None, - coder=coders.StrUtf8Coder()): - self.topic = topic - self.subscription = subscription - self.id_label = id_label - self.coder = coder - - @property - def format(self): - """Source format name required for remote execution.""" - return 'pubsub' - - def display_data(self): - return {'id_label': - DisplayDataItem(self.id_label, - label='ID Label Attribute').drop_if_none(), - 'topic': - DisplayDataItem(self.topic, - label='Pubsub Topic'), - 'subscription': - DisplayDataItem(self.subscription, - label='Pubsub Subscription').drop_if_none()} - - def reader(self): - raise NotImplementedError( - 'PubSubSource is not supported in local execution.') - - -class PubSubSink(dataflow_io.NativeSink): - """Deprecated: do not use. - - Sink for writing to a given Cloud Pub/Sub topic.""" - - def __init__(self, topic, coder=coders.StrUtf8Coder()): - self.topic = topic - self.coder = coder - - @property - def format(self): - """Sink format name required for remote execution.""" - return 'pubsub' - - def display_data(self): - return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')} - - def writer(self): - raise NotImplementedError( - 'PubSubSink is not supported in local execution.') diff --git a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py index e65660060db26..6d4a703bba1b6 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py @@ -73,7 +73,7 @@ # Update this version to the next version whenever there is a change that will # require changes to the execution environment. # This should be in the beam-[version]-[date] format, date is optional. -BEAM_CONTAINER_VERSION = 'beam-2.1.0-20170601' +BEAM_CONTAINER_VERSION = 'beam-2.1.0-20170626' # Standard file names used for staging files. WORKFLOW_TARBALL_FILE = 'workflow.tar.gz' From bec32fe93c6b5c16563d7ea4b877a2dee3352fee Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Fri, 16 Jun 2017 14:56:07 -0700 Subject: [PATCH 112/200] Reintroduces DoFn.ProcessContinuation (Dataflow worker compatibility part) --- .../src/main/java/org/apache/beam/sdk/transforms/DoFn.java | 3 +++ .../sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java | 6 ++++++ .../org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java | 4 +++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index e711ac2f297aa..fb6d0ee4ffe58 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -677,6 +677,9 @@ public interface OutputReceiver { @Experimental(Kind.SPLITTABLE_DO_FN) public @interface UnboundedPerElement {} + /** Temporary, do not use. See https://issues.apache.org/jira/browse/BEAM-1904 */ + public class ProcessContinuation {} + /** * Finalize the {@link DoFn} construction to prepare for processing. * This method should be called by runners before any processing methods. 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 5d5887a3c59b3..4f67db4b33ce3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -49,6 +49,7 @@ import net.bytebuddy.implementation.bytecode.assign.Assigner; import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.constant.NullConstant; import net.bytebuddy.implementation.bytecode.constant.TextConstant; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; @@ -667,6 +668,11 @@ protected StackManipulation beforeDelegation(MethodDescription instrumentedMetho } return new StackManipulation.Compound(pushParameters); } + + @Override + protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { + return new StackManipulation.Compound(NullConstant.INSTANCE, MethodReturn.REFERENCE); + } } private static class UserCodeMethodInvocation implements StackManipulation { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index 6fd40523c28c6..ed81f42870bbe 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -53,8 +53,10 @@ public interface DoFnInvoker { * Invoke the {@link DoFn.ProcessElement} method on the bound {@link DoFn}. * * @param extra Factory for producing extra parameter objects (such as window), if necessary. + * @return {@code null} - see JIRA + * tracking the complete removal of {@link DoFn.ProcessContinuation}. */ - void invokeProcessElement(ArgumentProvider extra); + DoFn.ProcessContinuation invokeProcessElement(ArgumentProvider extra); /** Invoke the appropriate {@link DoFn.OnTimer} method on the bound {@link DoFn}. 
*/ void invokeOnTimer(String timerId, ArgumentProvider arguments); From 2052cc7689b4ed53f817f56dd71a32235fb083ca Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Fri, 23 Jun 2017 10:16:30 -0700 Subject: [PATCH 113/200] Bump Dataflow worker to 0623 --- runners/google-cloud-dataflow-java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index fbb0b87fc6204..2ba163bdf84e0 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -33,7 +33,7 @@ jar - beam-master-20170622 + beam-master-20170623 1 6 From eb379e76adaa9c4b4e24a4b3c5757be8523d95c4 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Mon, 26 Jun 2017 16:54:00 -0700 Subject: [PATCH 114/200] Implement streaming GroupByKey in Python DirectRunner --- .../runners/direct/direct_runner.py | 29 +++- .../runners/direct/evaluation_context.py | 2 +- .../runners/direct/transform_evaluator.py | 138 +++++++++++++++++- .../python/apache_beam/runners/direct/util.py | 25 ++-- .../runners/direct/watermark_manager.py | 26 ++-- .../apache_beam/testing/test_stream_test.py | 37 ++++- sdks/python/apache_beam/transforms/trigger.py | 16 ++ 7 files changed, 239 insertions(+), 34 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index d80ef102e0363..2a75977576128 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -34,6 +34,7 @@ from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache from apache_beam.transforms.core import _GroupAlsoByWindow +from apache_beam.transforms.core import _GroupByKeyOnly from apache_beam.options.pipeline_options import DirectOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import RuntimeValueProvider @@ -47,6 +48,13 @@ V = typehints.TypeVariable('V') +@typehints.with_input_types(typehints.KV[K, V]) +@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) +class _StreamingGroupByKeyOnly(_GroupByKeyOnly): + """Streaming GroupByKeyOnly placeholder for overriding in DirectRunner.""" + pass + + @typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]]) @typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow): @@ -79,17 +87,24 @@ def apply_CombinePerKey(self, transform, pcoll): except NotImplementedError: return transform.expand(pcoll) + def apply__GroupByKeyOnly(self, transform, pcoll): + if (transform.__class__ == _GroupByKeyOnly and + pcoll.pipeline._options.view_as(StandardOptions).streaming): + # Use specialized streaming implementation, if requested. + type_hints = transform.get_type_hints() + return pcoll | (_StreamingGroupByKeyOnly() + .with_input_types(*type_hints.input_types[0]) + .with_output_types(*type_hints.output_types[0])) + return transform.expand(pcoll) + def apply__GroupAlsoByWindow(self, transform, pcoll): if (transform.__class__ == _GroupAlsoByWindow and pcoll.pipeline._options.view_as(StandardOptions).streaming): # Use specialized streaming implementation, if requested. - raise NotImplementedError( - 'Streaming support is not yet available on the DirectRunner.') - # TODO(ccy): enable when streaming implementation is plumbed through. 
- # type_hints = transform.get_type_hints() - # return pcoll | (_StreamingGroupAlsoByWindow(transform.windowing) - # .with_input_types(*type_hints.input_types[0]) - # .with_output_types(*type_hints.output_types[0])) + type_hints = transform.get_type_hints() + return pcoll | (_StreamingGroupAlsoByWindow(transform.windowing) + .with_input_types(*type_hints.input_types[0]) + .with_output_types(*type_hints.output_types[0])) return transform.expand(pcoll) def run(self, pipeline): diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 669a68a13c7da..54c407c1c866f 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -213,7 +213,7 @@ def handle_result( result.unprocessed_bundles) self._watermark_manager.update_watermarks( completed_bundle, result.transform, completed_timers, - committed_bundles, unprocessed_bundles, result.watermark_hold) + committed_bundles, unprocessed_bundles, result.keyed_watermark_holds) self._metrics.commit_logical(completed_bundle, result.logical_metric_updates) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 3aefbb8d5a1b2..67b24927e2e38 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -27,6 +27,8 @@ import apache_beam.io as io from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState +from apache_beam.runners.direct.direct_runner import _StreamingGroupByKeyOnly +from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.runners.direct.util import KeyedWorkItem from apache_beam.runners.direct.util import TransformResult @@ -38,6 +40,7 @@ from apache_beam.transforms import core from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue +from apache_beam.transforms.trigger import create_trigger_driver from apache_beam.transforms.trigger import _CombiningValueStateTag from apache_beam.transforms.trigger import _ListStateTag from apache_beam.transforms.trigger import TimeDomain @@ -63,6 +66,8 @@ def __init__(self, evaluation_context): core.Flatten: _FlattenEvaluator, core.ParDo: _ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, + _StreamingGroupByKeyOnly: _StreamingGroupByKeyOnlyEvaluator, + _StreamingGroupAlsoByWindow: _StreamingGroupAlsoByWindowEvaluator, _NativeWrite: _NativeWriteEvaluator, TestStream: _TestStreamEvaluator, } @@ -125,7 +130,10 @@ def should_execute_serially(self, applied_ptransform): True if executor should execute applied_ptransform serially. """ return isinstance(applied_ptransform.transform, - (core._GroupByKeyOnly, _NativeWrite)) + (core._GroupByKeyOnly, + _StreamingGroupByKeyOnly, + _StreamingGroupAlsoByWindow, + _NativeWrite,)) class RootBundleProvider(object): @@ -234,7 +242,7 @@ def process_timer_wrapper(self, timer_firing): timer and passes it to process_element(). Evaluator subclasses which desire different timer delivery semantics can override process_timer(). 
""" - state = self.step_context.get_keyed_state(timer_firing.key) + state = self.step_context.get_keyed_state(timer_firing.encoded_key) state.clear_timer( timer_firing.window, timer_firing.name, timer_firing.time_domain) self.process_timer(timer_firing) @@ -242,7 +250,9 @@ def process_timer_wrapper(self, timer_firing): def process_timer(self, timer_firing): """Default process_timer() impl. generating KeyedWorkItem element.""" self.process_element( - KeyedWorkItem(timer_firing.key, timer_firing=timer_firing)) + GlobalWindows.windowed_value( + KeyedWorkItem(timer_firing.encoded_key, + timer_firings=[timer_firing]))) def process_element(self, element): """Processes a new element as part of the current bundle.""" @@ -343,7 +353,8 @@ def finish_bundle(self): unprocessed_bundles.append(unprocessed_bundle) hold = self.watermark return TransformResult( - self._applied_ptransform, self.bundles, unprocessed_bundles, None, hold) + self._applied_ptransform, self.bundles, unprocessed_bundles, None, + {None: hold}) class _FlattenEvaluator(_TransformEvaluator): @@ -547,7 +558,122 @@ def len_element_fn(element): None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, bundles, [], None, hold) + self._applied_ptransform, bundles, [], None, {None: hold}) + + +class _StreamingGroupByKeyOnlyEvaluator(_TransformEvaluator): + """TransformEvaluator for _StreamingGroupByKeyOnly transform. + + The _GroupByKeyOnlyEvaluator buffers elements until its input watermark goes + to infinity, which is suitable for batch mode execution. During streaming + mode execution, we emit each bundle as it comes to the next transform. + """ + + MAX_ELEMENT_PER_BUNDLE = None + + def __init__(self, evaluation_context, applied_ptransform, + input_committed_bundle, side_inputs, scoped_metrics_container): + assert not side_inputs + super(_StreamingGroupByKeyOnlyEvaluator, self).__init__( + evaluation_context, applied_ptransform, input_committed_bundle, + side_inputs, scoped_metrics_container) + + def start_bundle(self): + self.gbk_items = collections.defaultdict(list) + + assert len(self._outputs) == 1 + self.output_pcollection = list(self._outputs)[0] + + # The input type of a GroupByKey will be KV[Any, Any] or more specific. + kv_type_hint = ( + self._applied_ptransform.transform.get_type_hints().input_types[0]) + self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) + + def process_element(self, element): + if (isinstance(element, WindowedValue) + and isinstance(element.value, collections.Iterable) + and len(element.value) == 2): + k, v = element.value + self.gbk_items[self.key_coder.encode(k)].append(v) + else: + raise TypeCheckError('Input to _GroupByKeyOnly must be a PCollection of ' + 'windowed key-value pairs. Instead received: %r.' + % element) + + def finish_bundle(self): + bundles = [] + bundle = None + for encoded_k, vs in self.gbk_items.iteritems(): + if not bundle: + bundle = self._evaluation_context.create_bundle( + self.output_pcollection) + bundles.append(bundle) + kwi = KeyedWorkItem(encoded_k, elements=vs) + bundle.add(GlobalWindows.windowed_value(kwi)) + + return TransformResult( + self._applied_ptransform, bundles, [], None, None) + + +class _StreamingGroupAlsoByWindowEvaluator(_TransformEvaluator): + """TransformEvaluator for the _StreamingGroupAlsoByWindow transform. + + This evaluator is only used in streaming mode. In batch mode, the + GroupAlsoByWindow operation is evaluated as a normal DoFn, as defined + in transforms/core.py. 
+ """ + + def __init__(self, evaluation_context, applied_ptransform, + input_committed_bundle, side_inputs, scoped_metrics_container): + assert not side_inputs + super(_StreamingGroupAlsoByWindowEvaluator, self).__init__( + evaluation_context, applied_ptransform, input_committed_bundle, + side_inputs, scoped_metrics_container) + + def start_bundle(self): + assert len(self._outputs) == 1 + self.output_pcollection = list(self._outputs)[0] + self.step_context = self._execution_context.get_step_context() + self.driver = create_trigger_driver( + self._applied_ptransform.transform.windowing) + self.gabw_items = [] + self.keyed_holds = {} + + # The input type of a GroupAlsoByWindow will be KV[Any, Iter[Any]] or more + # specific. + kv_type_hint = ( + self._applied_ptransform.transform.get_type_hints().input_types[0]) + self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0]) + + def process_element(self, element): + kwi = element.value + assert isinstance(kwi, KeyedWorkItem), kwi + encoded_k, timer_firings, vs = ( + kwi.encoded_key, kwi.timer_firings, kwi.elements) + k = self.key_coder.decode(encoded_k) + state = self.step_context.get_keyed_state(encoded_k) + + for timer_firing in timer_firings: + for wvalue in self.driver.process_timer( + timer_firing.window, timer_firing.name, timer_firing.time_domain, + timer_firing.timestamp, state): + self.gabw_items.append(wvalue.with_value((k, wvalue.value))) + if vs: + for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): + self.gabw_items.append(wvalue.with_value((k, wvalue.value))) + + self.keyed_holds[encoded_k] = state.get_earliest_hold() + + def finish_bundle(self): + bundles = [] + if self.gabw_items: + bundle = self._evaluation_context.create_bundle(self.output_pcollection) + for item in self.gabw_items: + bundle.add(item) + bundles.append(bundle) + + return TransformResult( + self._applied_ptransform, bundles, [], None, self.keyed_holds) class _NativeWriteEvaluator(_TransformEvaluator): @@ -612,4 +738,4 @@ def finish_bundle(self): None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF) return TransformResult( - self._applied_ptransform, [], [], None, hold) + self._applied_ptransform, [], [], None, {None: hold}) diff --git a/sdks/python/apache_beam/runners/direct/util.py b/sdks/python/apache_beam/runners/direct/util.py index 8c846fc55eb4d..10f7b294c1306 100644 --- a/sdks/python/apache_beam/runners/direct/util.py +++ b/sdks/python/apache_beam/runners/direct/util.py @@ -27,13 +27,21 @@ class TransformResult(object): """Result of evaluating an AppliedPTransform with a TransformEvaluator.""" def __init__(self, applied_ptransform, uncommitted_output_bundles, - unprocessed_bundles, counters, watermark_hold, + unprocessed_bundles, counters, keyed_watermark_holds, undeclared_tag_values=None): self.transform = applied_ptransform self.uncommitted_output_bundles = uncommitted_output_bundles self.unprocessed_bundles = unprocessed_bundles self.counters = counters - self.watermark_hold = watermark_hold + # Mapping of key -> earliest hold timestamp or None. Keys should be + # strings or None. + # + # For each key, we receive as its corresponding value the earliest + # watermark hold for that key (the key can be None for global state), past + # which the output watermark for the currently-executing step will not + # advance. If the value is None or utils.timestamp.MAX_TIMESTAMP, the + # watermark hold will be removed. 
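+ # For example (illustrative values): {'key1': Timestamp(20), None: Timestamp(35)} holds
+ # this step's output watermark at 20, and a later {'key1': None} releases the hold that
+ # was registered for 'key1'.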
+ self.keyed_watermark_holds = keyed_watermark_holds or {} # Only used when caching (materializing) all values is requested. self.undeclared_tag_values = undeclared_tag_values # Populated by the TransformExecutor. @@ -43,8 +51,8 @@ def __init__(self, applied_ptransform, uncommitted_output_bundles, class TimerFiring(object): """A single instance of a fired timer.""" - def __init__(self, key, window, name, time_domain, timestamp): - self.key = key + def __init__(self, encoded_key, window, name, time_domain, timestamp): + self.encoded_key = encoded_key self.window = window self.name = name self.time_domain = time_domain @@ -53,8 +61,7 @@ def __init__(self, key, window, name, time_domain, timestamp): class KeyedWorkItem(object): """A keyed item that can either be a timer firing or a list of elements.""" - def __init__(self, key, timer_firing=None, elements=None): - self.key = key - assert not timer_firing and elements - self.timer_firing = timer_firing - self.elements = elements + def __init__(self, encoded_key, timer_firings=None, elements=None): + self.encoded_key = encoded_key + self.timer_firings = timer_firings or [] + self.elements = elements or [] diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 4aa2bb4342f18..935998d27de0f 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -94,13 +94,13 @@ def get_watermarks(self, applied_ptransform): def update_watermarks(self, completed_committed_bundle, applied_ptransform, completed_timers, outputs, unprocessed_bundles, - earliest_hold): + keyed_earliest_holds): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( completed_committed_bundle, applied_ptransform, completed_timers, outputs, unprocessed_bundles) tw = self.get_watermarks(applied_ptransform) - tw.hold(earliest_hold) + tw.hold(keyed_earliest_holds) self._refresh_watermarks(applied_ptransform) def _update_pending(self, input_committed_bundle, applied_ptransform, @@ -161,7 +161,7 @@ def __init__(self, clock, keyed_states, transform): self._input_transform_watermarks = [] self._input_watermark = WatermarkManager.WATERMARK_NEG_INF self._output_watermark = WatermarkManager.WATERMARK_NEG_INF - self._earliest_hold = WatermarkManager.WATERMARK_POS_INF + self._keyed_earliest_holds = {} self._pending = set() # Scheduled bundles targeted for this transform. 
self._fired_timers = set() self._lock = threading.Lock() @@ -187,11 +187,13 @@ def output_watermark(self): with self._lock: return self._output_watermark - def hold(self, value): + def hold(self, keyed_earliest_holds): with self._lock: - if value is None: - value = WatermarkManager.WATERMARK_POS_INF - self._earliest_hold = value + for key, hold_value in keyed_earliest_holds.iteritems(): + self._keyed_earliest_holds[key] = hold_value + if (hold_value is None or + hold_value == WatermarkManager.WATERMARK_POS_INF): + del self._keyed_earliest_holds[key] def add_pending(self, pending): with self._lock: @@ -230,7 +232,11 @@ def refresh(self): self._input_watermark = max(self._input_watermark, min(pending_holder, producer_watermark)) - new_output_watermark = min(self._input_watermark, self._earliest_hold) + earliest_hold = WatermarkManager.WATERMARK_POS_INF + for hold in self._keyed_earliest_holds.values(): + if hold < earliest_hold: + earliest_hold = hold + new_output_watermark = min(self._input_watermark, earliest_hold) advanced = new_output_watermark > self._output_watermark self._output_watermark = new_output_watermark @@ -246,11 +252,11 @@ def extract_fired_timers(self): return False fired_timers = [] - for key, state in self._keyed_states.iteritems(): + for encoded_key, state in self._keyed_states.iteritems(): timers = state.get_timers(watermark=self._input_watermark) for expired in timers: window, (name, time_domain, timestamp) = expired fired_timers.append( - TimerFiring(key, window, name, time_domain, timestamp)) + TimerFiring(encoded_key, window, name, time_domain, timestamp)) self._fired_timers.update(fired_timers) return fired_timers diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index 071c7cd3d6c00..b7ca141f0598e 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -20,12 +20,15 @@ import unittest import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import ElementEvent from apache_beam.testing.test_stream import ProcessingTimeEvent from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_stream import WatermarkEvent from apache_beam.testing.util import assert_that, equal_to +from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -98,7 +101,9 @@ def process(self, element=beam.DoFn.ElementParam, timestamp=beam.DoFn.TimestampParam): yield (element, timestamp) - p = TestPipeline() + options = PipelineOptions() + options.view_as(StandardOptions).streaming = True + p = TestPipeline(options=options) my_record_fn = RecordFn() records = p | test_stream | beam.ParDo(my_record_fn) assert_that(records, equal_to([ @@ -111,6 +116,36 @@ def process(self, element=beam.DoFn.ElementParam, ('last', timestamp.Timestamp(310)),])) p.run() + def test_gbk_execution(self): + test_stream = (TestStream() + .advance_watermark_to(10) + .add_elements(['a', 'b', 'c']) + .advance_watermark_to(20) + .add_elements(['d']) + .add_elements(['e']) + .advance_processing_time(10) + .advance_watermark_to(300) + .add_elements([TimestampedValue('late', 12)]) + .add_elements([TimestampedValue('last', 
310)])) + + options = PipelineOptions() + options.view_as(StandardOptions).streaming = True + p = TestPipeline(options=options) + records = (p + | test_stream + | beam.WindowInto(FixedWindows(15)) + | beam.Map(lambda x: ('k', x)) + | beam.GroupByKey()) + # TODO(BEAM-2519): timestamp assignment for elements from a GBK should + # respect the TimestampCombiner. The test below should also verify the + # timestamps of the outputted elements once this is implemented. + assert_that(records, equal_to([ + ('k', ['a', 'b', 'c']), + ('k', ['d', 'e']), + ('k', ['late']), + ('k', ['last'])])) + p.run() + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 7ff44fa8fde39..f77fa1a996662 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -36,6 +36,7 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.timestamp import TIME_GRANULARITY # AfterCount is experimental. No backwards compatibility guarantees. @@ -1066,6 +1067,8 @@ def set_timer(self, window, name, time_domain, timestamp): def clear_timer(self, window, name, time_domain): self.timers[window].pop((name, time_domain), None) + if not self.timers[window]: + del self.timers[window] def get_window(self, window_id): return window_id @@ -1117,6 +1120,19 @@ def get_timers(self, clear=False, watermark=MAX_TIMESTAMP): def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): return self.get_timers(clear=True, watermark=watermark) + def get_earliest_hold(self): + earliest_hold = MAX_TIMESTAMP + for unused_window, tagged_states in self.state.iteritems(): + # TODO(BEAM-2519): currently, this assumes that the watermark hold tag is + # named "watermark". This is currently only true because the only place + # watermark holds are set is in the GeneralTriggerDriver, where we use + # this name. We should fix this by allowing enumeration of the tag types + # used in adding state. 
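+ # Illustrative example: if tagged_states == {'watermark': [Timestamp(25), Timestamp(40)]},
+ # the hold computed below is Timestamp(25) - TIME_GRANULARITY.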
+ if 'watermark' in tagged_states and tagged_states['watermark']: + hold = min(tagged_states['watermark']) - TIME_GRANULARITY + earliest_hold = min(earliest_hold, hold) + return earliest_hold + def __repr__(self): state_str = '\n'.join('%s: %s' % (key, dict(state)) for key, state in self.state.items()) From cbfcad823972270171552e556b7fa8d8d4882f14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Wed, 28 Dec 2016 15:06:09 +0100 Subject: [PATCH 115/200] [BEAM-1237] Create AmqpIO --- sdks/java/io/amqp/pom.xml | 100 +++++ .../org/apache/beam/sdk/io/amqp/AmqpIO.java | 397 ++++++++++++++++++ .../beam/sdk/io/amqp/AmqpMessageCoder.java | 79 ++++ .../AmqpMessageCoderProviderRegistrar.java | 44 ++ .../apache/beam/sdk/io/amqp/package-info.java | 22 + .../apache/beam/sdk/io/amqp/AmqpIOTest.java | 148 +++++++ .../sdk/io/amqp/AmqpMessageCoderTest.java | 89 ++++ sdks/java/io/pom.xml | 1 + 8 files changed, 880 insertions(+) create mode 100644 sdks/java/io/amqp/pom.xml create mode 100644 sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java create mode 100644 sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java create mode 100644 sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java create mode 100644 sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/package-info.java create mode 100644 sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpIOTest.java create mode 100644 sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java diff --git a/sdks/java/io/amqp/pom.xml b/sdks/java/io/amqp/pom.xml new file mode 100644 index 0000000000000..45b295dfce244 --- /dev/null +++ b/sdks/java/io/amqp/pom.xml @@ -0,0 +1,100 @@ + + + + + 4.0.0 + + + org.apache.beam + beam-sdks-java-io-parent + 2.1.0-SNAPSHOT + ../pom.xml + + + beam-sdks-java-io-amqp + Apache Beam :: SDKs :: Java :: IO :: AMQP + IO to read and write using AMQP 1.0 protocol (http://www.amqp.org). + + + + org.apache.beam + beam-sdks-java-core + + + + org.slf4j + slf4j-api + + + + joda-time + joda-time + + + + com.google.guava + guava + + + + com.google.code.findbugs + jsr305 + + + + org.apache.qpid + proton-j + 0.13.1 + + + + + com.google.auto.value + auto-value + provided + + + com.google.auto.service + auto-service + true + + + + + org.slf4j + slf4j-jdk14 + test + + + junit + junit + test + + + org.hamcrest + hamcrest-all + test + + + org.apache.beam + beam-runners-direct-java + test + + + + \ No newline at end of file diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java new file mode 100644 index 0000000000000..b9a0be9a078fe --- /dev/null +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.amqp; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import com.google.common.base.Joiner; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; + +import javax.annotation.Nullable; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.qpid.proton.message.Message; +import org.apache.qpid.proton.messenger.Messenger; +import org.apache.qpid.proton.messenger.Tracker; +import org.joda.time.Duration; +import org.joda.time.Instant; + +/** + * AmqpIO supports AMQP 1.0 protocol using the Apache QPid Proton-J library. + * + *

<p>It's also possible to use the AMQP 1.0 protocol via the Apache Qpid JMS connection factory and
+ * the Apache Beam JmsIO.
+ *
+ * <h3>Binding AMQP and receiving messages</h3>
+ *
+ * <p>The {@link AmqpIO} {@link Read} can bind an AMQP listener endpoint and receive messages. It
+ * can also connect to an AMQP broker (such as Apache Qpid or Apache ActiveMQ).
+ *
+ * <p>{@link AmqpIO} {@link Read} returns an unbounded {@link PCollection} of {@link Message}
+ * containing the received messages.
+ *
+ * <p>To configure an AMQP source, you have to provide a list of addresses where it will receive
+ * messages. An address has the following form: {@code
+ * [amqp[s]://][user[:password]@]domain[/[name]]} where {@code domain} can be one of {@code
+ * host | host:port | ip | ip:port | name}. NB: the {@code ~} character allows binding an AMQP
+ * listener instead of connecting to a remote broker. For instance, {@code amqp://~0.0.0.0:1234}
+ * will bind an AMQP listener on any network interface on port 1234.
+ *
+ * <p>The following example illustrates how to configure an AMQP source:
+ *
+ * <pre>{@code
+ *
+ *  pipeline.apply(AmqpIO.read()
+ *    .withAddresses(Collections.singletonList("amqp://host:1234")))
+ *
+ * }</pre>
+ *
+ * <h3>Sending messages to an AMQP endpoint</h3>
+ *
+ * <p>{@link AmqpIO} provides a sink to send {@link PCollection} elements as messages.
+ *
+ * <p>As for the {@link Read}, {@link AmqpIO} {@link Write} requires a list of addresses to which
+ * messages are sent. The following example illustrates how to configure the {@link AmqpIO}
+ * {@link Write}:
+ *
+ * <pre>{@code
+ *
+ *  pipeline
+ *    .apply(...) // provide PCollection<Message>
+ *    .apply(AmqpIO.write());
+ *
+ * }</pre>
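+ *
+ * <p>For tests the read can also be made bounded (a sketch; the listener address and record
+ * count below are placeholders):
+ *
+ * <pre>{@code
+ *
+ *  pipeline.apply(AmqpIO.read()
+ *    .withMaxNumRecords(1000)
+ *    .withAddresses(Collections.singletonList("amqp://~0.0.0.0:1234")))
+ *
+ * }</pre>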
    + */ +public class AmqpIO { + + public static Read read() { + return new AutoValue_AmqpIO_Read.Builder().setMaxNumRecords(Long.MAX_VALUE).build(); + } + + public static Write write() { + return new AutoValue_AmqpIO_Write(); + } + + private AmqpIO() { + } + + /** + * A {@link PTransform} to read/receive messages using AMQP 1.0 protocol. + */ + @AutoValue + public abstract static class Read extends PTransform> { + + @Nullable abstract List addresses(); + abstract long maxNumRecords(); + @Nullable abstract Duration maxReadTime(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setAddresses(List addresses); + abstract Builder setMaxNumRecords(long maxNumRecords); + abstract Builder setMaxReadTime(Duration maxReadTime); + abstract Read build(); + } + + /** + * Define the AMQP addresses where to receive messages. + */ + public Read withAddresses(List addresses) { + checkArgument(addresses != null, "AmqpIO.read().withAddresses(addresses) called with null" + + " addresses"); + checkArgument(!addresses.isEmpty(), "AmqpIO.read().withAddresses(addresses) called with " + + "empty addresses list"); + return builder().setAddresses(addresses).build(); + } + + /** + * Define the max number of records received by the {@link Read}. + * When the max number of records is lower than {@code Long.MAX_VALUE}, the {@link Read} will + * provide a bounded {@link PCollection}. + */ + public Read withMaxNumRecords(long maxNumRecords) { + checkArgument(maxReadTime() == null, + "maxNumRecord and maxReadTime are exclusive"); + return builder().setMaxNumRecords(maxNumRecords).build(); + } + + /** + * Define the max read time (duration) while the {@link Read} will receive messages. + * When this max read time is not null, the {@link Read} will provide a bounded + * {@link PCollection}. 
+ */ + public Read withMaxReadTime(Duration maxReadTime) { + checkArgument(maxNumRecords() == Long.MAX_VALUE, + "maxNumRecord and maxReadTime are exclusive"); + return builder().setMaxReadTime(maxReadTime).build(); + } + + @Override + public void validate(PipelineOptions pipelineOptions) { + checkState(addresses() != null, "AmqIO.read() requires addresses list to be set via " + + "withAddresses(addresses)"); + checkState(!addresses().isEmpty(), "AmqIO.read() requires a non-empty addresses list to be" + + " set via withAddresses(addresses)"); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("addresses", Joiner.on(" ").join(addresses()))); + } + + @Override + public PCollection expand(PBegin input) { + org.apache.beam.sdk.io.Read.Unbounded unbounded = + org.apache.beam.sdk.io.Read.from(new UnboundedAmqpSource(this)); + + PTransform> transform = unbounded; + + if (maxNumRecords() != Long.MAX_VALUE) { + transform = unbounded.withMaxNumRecords(maxNumRecords()); + } else if (maxReadTime() != null) { + transform = unbounded.withMaxReadTime(maxReadTime()); + } + + return input.getPipeline().apply(transform); + } + + } + + private static class AmqpCheckpointMark implements UnboundedSource.CheckpointMark, Serializable { + + private transient Messenger messenger; + private transient List trackers = new ArrayList<>(); + + public AmqpCheckpointMark() { + } + + @Override + public void finalizeCheckpoint() { + for (Tracker tracker : trackers) { + // flag as not cumulative + messenger.accept(tracker, 0); + } + trackers.clear(); + } + + // set an empty list to messages when deserialize + private void readObject(java.io.ObjectInputStream stream) + throws java.io.IOException, ClassNotFoundException { + trackers = new ArrayList<>(); + } + + } + + private static class UnboundedAmqpSource + extends UnboundedSource { + + private final Read spec; + + public UnboundedAmqpSource(Read spec) { + this.spec = spec; + } + + @Override + public List split(int desiredNumSplits, + PipelineOptions pipelineOptions) { + // amqp is a queue system, so, it's possible to have multiple concurrent sources, even if + // they bind the listener + List sources = new ArrayList<>(); + for (int i = 0; i < Math.max(1, desiredNumSplits); ++i) { + sources.add(new UnboundedAmqpSource(spec)); + } + return sources; + } + + @Override + public UnboundedReader createReader(PipelineOptions pipelineOptions, + AmqpCheckpointMark checkpointMark) { + return new UnboundedAmqpReader(this, checkpointMark); + } + + @Override + public Coder getDefaultOutputCoder() { + return new AmqpMessageCoder(); + } + + @Override + public Coder getCheckpointMarkCoder() { + return SerializableCoder.of(AmqpCheckpointMark.class); + } + + @Override + public void validate() { + spec.validate(null); + } + + } + + private static class UnboundedAmqpReader extends UnboundedSource.UnboundedReader { + + private final UnboundedAmqpSource source; + + private Messenger messenger; + private Message current; + private Instant currentTimestamp; + private Instant watermark = new Instant(Long.MIN_VALUE); + private AmqpCheckpointMark checkpointMark; + + public UnboundedAmqpReader(UnboundedAmqpSource source, AmqpCheckpointMark checkpointMark) { + this.source = source; + this.current = null; + if (checkpointMark != null) { + this.checkpointMark = checkpointMark; + } else { + this.checkpointMark = new AmqpCheckpointMark(); + } + } + + @Override + public Instant getWatermark() { + return watermark; + } + + @Override + public 
Instant getCurrentTimestamp() { + if (current == null) { + throw new NoSuchElementException(); + } + return currentTimestamp; + } + + @Override + public Message getCurrent() { + if (current == null) { + throw new NoSuchElementException(); + } + return current; + } + + @Override + public UnboundedSource.CheckpointMark getCheckpointMark() { + return checkpointMark; + } + + @Override + public UnboundedAmqpSource getCurrentSource() { + return source; + } + + @Override + public boolean start() throws IOException { + Read spec = source.spec; + messenger = Messenger.Factory.create(); + messenger.start(); + for (String address : spec.addresses()) { + messenger.subscribe(address); + } + checkpointMark.messenger = messenger; + return advance(); + } + + @Override + public boolean advance() { + messenger.recv(); + if (messenger.incoming() <= 0) { + current = null; + return false; + } + Message message = messenger.get(); + Tracker tracker = messenger.incomingTracker(); + checkpointMark.trackers.add(tracker); + currentTimestamp = new Instant(message.getCreationTime()); + watermark = currentTimestamp; + current = message; + return true; + } + + @Override + public void close() { + if (messenger != null) { + messenger.stop(); + } + } + + } + + /** + * A {@link PTransform} to send messages using AMQP 1.0 protocol. + */ + @AutoValue + public abstract static class Write extends PTransform, PDone> { + + @Override + public PDone expand(PCollection input) { + input.apply(ParDo.of(new WriteFn(this))); + return PDone.in(input.getPipeline()); + } + + private static class WriteFn extends DoFn { + + private final Write spec; + + private transient Messenger messenger; + + public WriteFn(Write spec) { + this.spec = spec; + } + + @Setup + public void setup() throws Exception { + messenger = Messenger.Factory.create(); + messenger.start(); + } + + @ProcessElement + public void processElement(ProcessContext processContext) throws Exception { + Message message = processContext.element(); + messenger.put(message); + messenger.send(); + } + + @Teardown + public void teardown() throws Exception { + if (messenger != null) { + messenger.stop(); + } + } + + } + + } + +} diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java new file mode 100644 index 0000000000000..5a552600168f7 --- /dev/null +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoder.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.amqp; + +import com.google.common.io.ByteStreams; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.BufferOverflowException; + +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.util.VarInt; +import org.apache.qpid.proton.message.Message; + +/** + * A coder for AMQP message. + */ +public class AmqpMessageCoder extends CustomCoder { + + private static final int[] MESSAGE_SIZES = new int[]{ + 8 * 1024, + 64 * 1024, + 1 * 1024 * 1024, + 64 * 1024 * 1024 + }; + + static AmqpMessageCoder of() { + return new AmqpMessageCoder(); + } + + @Override + public void encode(Message value, OutputStream outStream) throws CoderException, IOException { + for (int maxMessageSize : MESSAGE_SIZES) { + try { + encode(value, outStream, maxMessageSize); + return; + } catch (Exception e) { + continue; + } + } + throw new CoderException("Message is larger than the max size supported by the coder"); + } + + private void encode(Message value, OutputStream outStream, int messageSize) throws + IOException, BufferOverflowException { + byte[] data = new byte[messageSize]; + int bytesWritten = value.encode(data, 0, data.length); + VarInt.encode(bytesWritten, outStream); + outStream.write(data, 0, bytesWritten); + } + + @Override + public Message decode(InputStream inStream) throws CoderException, IOException { + Message message = Message.Factory.create(); + int bytesToRead = VarInt.decodeInt(inStream); + byte[] encodedMessage = new byte[bytesToRead]; + ByteStreams.readFully(inStream, encodedMessage); + message.decode(encodedMessage, 0, encodedMessage.length); + return message; + } + +} diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java new file mode 100644 index 0000000000000..bc3445cf97814 --- /dev/null +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderProviderRegistrar.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.amqp; + +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import org.apache.beam.sdk.coders.CoderProvider; +import org.apache.beam.sdk.coders.CoderProviderRegistrar; +import org.apache.beam.sdk.coders.CoderProviders; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.qpid.proton.message.Message; + +/** + * A {@link CoderProviderRegistrar} for standard types used with {@link AmqpIO}. 
+ */ +@AutoService(CoderProviderRegistrar.class) +public class AmqpMessageCoderProviderRegistrar implements CoderProviderRegistrar { + + @Override + public List getCoderProviders() { + return ImmutableList.of( + CoderProviders.forCoder(TypeDescriptor.of(Message.class), + AmqpMessageCoder.of())); + } + +} diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/package-info.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/package-info.java new file mode 100644 index 0000000000000..091f23424a3af --- /dev/null +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Transforms for reading and writing using AMQP 1.0 protocol. + */ +package org.apache.beam.sdk.io.amqp; diff --git a/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpIOTest.java b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpIOTest.java new file mode 100644 index 0000000000000..c8fe4e80f834e --- /dev/null +++ b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpIOTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.amqp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.qpid.proton.amqp.messaging.AmqpValue; +import org.apache.qpid.proton.message.Message; +import org.apache.qpid.proton.messenger.Messenger; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tests on {@link AmqpIO}. + */ +@RunWith(JUnit4.class) +public class AmqpIOTest { + + private static final Logger LOG = LoggerFactory.getLogger(AmqpIOTest.class); + + private int port; + + @Rule public TestPipeline pipeline = TestPipeline.create(); + + @Before + public void findFreeNetworkPort() throws Exception { + LOG.info("Finding free network port"); + ServerSocket socket = new ServerSocket(0); + port = socket.getLocalPort(); + socket.close(); + } + + @Test + public void testRead() throws Exception { + PCollection output = pipeline.apply(AmqpIO.read() + .withMaxNumRecords(100) + .withAddresses(Collections.singletonList("amqp://~localhost:" + port))); + PAssert.thatSingleton(output.apply(Count.globally())).isEqualTo(100L); + + Thread sender = new Thread() { + public void run() { + try { + Thread.sleep(500); + Messenger sender = Messenger.Factory.create(); + sender.start(); + for (int i = 0; i < 100; i++) { + Message message = Message.Factory.create(); + message.setAddress("amqp://localhost:" + port); + message.setBody(new AmqpValue("Test " + i)); + sender.put(message); + sender.send(); + } + sender.stop(); + } catch (Exception e) { + LOG.error("Sender error", e); + } + } + }; + try { + sender.start(); + pipeline.run(); + } finally { + sender.join(); + } + } + + @Test + public void testWrite() throws Exception { + final List received = new ArrayList<>(); + Thread receiver = new Thread() { + @Override + public void run() { + try { + Messenger messenger = Messenger.Factory.create(); + messenger.start(); + messenger.subscribe("amqp://~localhost:" + port); + while (received.size() < 100) { + messenger.recv(); + while (messenger.incoming() > 0) { + Message message = messenger.get(); + LOG.info("Received: " + message.getBody().toString()); + received.add(message.getBody().toString()); + } + } + messenger.stop(); + } catch (Exception e) { + LOG.error("Receiver error", e); + } + } + }; + LOG.info("Starting AMQP receiver"); + receiver.start(); + + List data = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + Message message = Message.Factory.create(); + message.setBody(new AmqpValue("Test " + i)); + message.setAddress("amqp://localhost:" + port); + message.setSubject("test"); + data.add(message); + } + pipeline.apply(Create.of(data).withCoder(AmqpMessageCoder.of())).apply(AmqpIO.write()); + LOG.info("Starting pipeline"); + try { + pipeline.run(); + } finally { + LOG.info("Join receiver thread"); + receiver.join(); + } + + assertEquals(100, received.size()); + for (int i = 0; i < 100; i++) { + assertTrue(received.contains("AmqpValue{Test " + i + "}")); + } + } + +} diff --git 
a/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java new file mode 100644 index 0000000000000..7a8efeb61c1ce --- /dev/null +++ b/sdks/java/io/amqp/src/test/java/org/apache/beam/sdk/io/amqp/AmqpMessageCoderTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.amqp; + +import static org.junit.Assert.assertEquals; + +import com.google.common.base.Joiner; + +import java.util.Collections; + +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.qpid.proton.amqp.messaging.AmqpValue; +import org.apache.qpid.proton.message.Message; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test on {@link AmqpMessageCoder}. + */ +@RunWith(JUnit4.class) +public class AmqpMessageCoderTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void encodeDecode() throws Exception { + Message message = Message.Factory.create(); + message.setBody(new AmqpValue("body")); + message.setAddress("address"); + message.setSubject("test"); + AmqpMessageCoder coder = AmqpMessageCoder.of(); + + Message clone = CoderUtils.clone(coder, message); + + assertEquals("AmqpValue{body}", clone.getBody().toString()); + assertEquals("address", clone.getAddress()); + assertEquals("test", clone.getSubject()); + } + + @Test + public void encodeDecodeTooMuchLargerMessage() throws Exception { + thrown.expect(CoderException.class); + Message message = Message.Factory.create(); + message.setAddress("address"); + message.setSubject("subject"); + String body = Joiner.on("").join(Collections.nCopies(64 * 1024 * 1024, " ")); + message.setBody(new AmqpValue(body)); + + AmqpMessageCoder coder = AmqpMessageCoder.of(); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, message); + } + + @Test + public void encodeDecodeLargeMessage() throws Exception { + Message message = Message.Factory.create(); + message.setAddress("address"); + message.setSubject("subject"); + String body = Joiner.on("").join(Collections.nCopies(32 * 1024 * 1024, " ")); + message.setBody(new AmqpValue(body)); + + AmqpMessageCoder coder = AmqpMessageCoder.of(); + + Message clone = CoderUtils.clone(coder, message); + + clone.getBody().toString().equals(message.getBody().toString()); + } + +} diff --git a/sdks/java/io/pom.xml b/sdks/java/io/pom.xml index 13cd418355daa..e5db41b726292 100644 --- a/sdks/java/io/pom.xml +++ b/sdks/java/io/pom.xml @@ -64,6 +64,7 @@ + amqp cassandra common elasticsearch From 
01b3febc62debf8b4fd1582225d59768d231601a Mon Sep 17 00:00:00 2001 From: manuzhang Date: Mon, 26 Jun 2017 20:55:13 +0800 Subject: [PATCH 116/200] [BEAM-2514] Improve error message on missing required value --- .../sdk/options/PipelineOptionsFactory.java | 18 +++++--- .../sdk/options/PipelineOptionsValidator.java | 34 ++++++++++++-- .../sdk/options/ProxyInvocationHandler.java | 4 ++ .../options/PipelineOptionsValidatorTest.java | 44 +++++++++++++++++++ .../options/ProxyInvocationHandlerTest.java | 7 +++ 5 files changed, 97 insertions(+), 10 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java index c0990cb108a86..d7e6cc84ee7ae 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java @@ -184,18 +184,20 @@ public static class Builder { private final String[] args; private final boolean validation; private final boolean strictParsing; + private final boolean isCli; // Do not allow direct instantiation private Builder() { - this(null, false, true); + this(null, false, true, false); } private Builder(String[] args, boolean validation, - boolean strictParsing) { + boolean strictParsing, boolean isCli) { this.defaultAppName = findCallersClassName(); this.args = args; this.validation = validation; this.strictParsing = strictParsing; + this.isCli = isCli; } /** @@ -237,7 +239,7 @@ private Builder(String[] args, boolean validation, */ public Builder fromArgs(String... args) { checkNotNull(args, "Arguments should not be null."); - return new Builder(args, validation, strictParsing); + return new Builder(args, validation, strictParsing, true); } /** @@ -247,7 +249,7 @@ public Builder fromArgs(String... args) { * validation. */ public Builder withValidation() { - return new Builder(args, true, strictParsing); + return new Builder(args, true, strictParsing, isCli); } /** @@ -255,7 +257,7 @@ public Builder withValidation() { * arguments. */ public Builder withoutStrictParsing() { - return new Builder(args, validation, false); + return new Builder(args, validation, false, isCli); } /** @@ -300,7 +302,11 @@ public T as(Class klass) { } if (validation) { - PipelineOptionsValidator.validate(klass, t); + if (isCli) { + PipelineOptionsValidator.validateCli(klass, t); + } else { + PipelineOptionsValidator.validate(klass, t); + } } return t; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsValidator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsValidator.java index bd54ec39bd74b..fcffd74c7a7eb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsValidator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsValidator.java @@ -43,9 +43,29 @@ public class PipelineOptionsValidator { * * @param klass The interface to fetch validation criteria from. * @param options The {@link PipelineOptions} to validate. - * @return The type + * @return Validated options. */ public static T validate(Class klass, PipelineOptions options) { + return validate(klass, options, false); + } + + /** + * Validates that the passed {@link PipelineOptions} from command line interface (CLI) + * conforms to all the validation criteria from the passed in interface. + * + *

    Note that the interface requested must conform to the validation criteria specified on + * {@link PipelineOptions#as(Class)}. + * + * @param klass The interface to fetch validation criteria from. + * @param options The {@link PipelineOptions} to validate. + * @return Validated options. + */ + public static T validateCli(Class klass, PipelineOptions options) { + return validate(klass, options, true); + } + + private static T validate(Class klass, PipelineOptions options, + boolean isCli) { checkNotNull(klass); checkNotNull(options); checkArgument(Proxy.isProxyClass(options.getClass())); @@ -67,9 +87,15 @@ public static T validate(Class klass, PipelineOpt requiredGroups.put(requiredGroup, method); } } else { - checkArgument(handler.invoke(asClassOptions, method, null) != null, - "Missing required value for [%s, \"%s\"]. ", - method, getDescription(method)); + if (isCli) { + checkArgument(handler.invoke(asClassOptions, method, null) != null, + "Missing required value for [--%s, \"%s\"]. ", + handler.getOptionName(method), getDescription(method)); + } else { + checkArgument(handler.invoke(asClassOptions, method, null) != null, + "Missing required value for [%s, \"%s\"]. ", + method, getDescription(method)); + } } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java index 3842388e8c0b9..926a7b957c300 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ProxyInvocationHandler.java @@ -166,6 +166,10 @@ public Object invoke(Object proxy, Method method, Object[] args) { + Arrays.toString(args) + "]."); } + public String getOptionName(Method method) { + return gettersToPropertyNames.get(method.getName()); + } + private void writeObject(java.io.ObjectOutputStream stream) throws IOException { throw new NotSerializableException( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsValidatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsValidatorTest.java index 120d5eddf628c..f8cd00f3de696 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsValidatorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsValidatorTest.java @@ -59,6 +59,18 @@ public void testWhenRequiredOptionIsSetAndCleared() { PipelineOptionsValidator.validate(Required.class, required); } + @Test + public void testWhenRequiredOptionIsSetAndClearedCli() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[--object, \"Fake Description\"]."); + + Required required = PipelineOptionsFactory.fromArgs(new String[]{"--object=blah"}) + .as(Required.class); + required.setObject(null); + PipelineOptionsValidator.validateCli(Required.class, required); + } + @Test public void testWhenRequiredOptionIsNeverSet() { expectedException.expect(IllegalArgumentException.class); @@ -70,6 +82,17 @@ public void testWhenRequiredOptionIsNeverSet() { PipelineOptionsValidator.validate(Required.class, required); } + + @Test + public void testWhenRequiredOptionIsNeverSetCli() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[--object, \"Fake Description\"]."); + + Required required = PipelineOptionsFactory.fromArgs(new 
String[]{}).as(Required.class); + PipelineOptionsValidator.validateCli(Required.class, required); + } + @Test public void testWhenRequiredOptionIsNeverSetOnSuperInterface() { expectedException.expect(IllegalArgumentException.class); @@ -81,6 +104,16 @@ public void testWhenRequiredOptionIsNeverSetOnSuperInterface() { PipelineOptionsValidator.validate(Required.class, options); } + @Test + public void testWhenRequiredOptionIsNeverSetOnSuperInterfaceCli() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[--object, \"Fake Description\"]."); + + PipelineOptions options = PipelineOptionsFactory.fromArgs(new String[]{}).create(); + PipelineOptionsValidator.validateCli(Required.class, options); + } + /** A test interface that overrides the parent's method. */ public interface SubClassValidation extends Required { @Override @@ -100,6 +133,17 @@ public void testValidationOnOverriddenMethods() throws Exception { PipelineOptionsValidator.validate(Required.class, required); } + @Test + public void testValidationOnOverriddenMethodsCli() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing required value for " + + "[--object, \"Fake Description\"]."); + + SubClassValidation required = PipelineOptionsFactory.fromArgs(new String[]{}) + .as(SubClassValidation.class); + PipelineOptionsValidator.validateCli(Required.class, required); + } + /** A test interface with a required group. */ public interface GroupRequired extends PipelineOptions { @Validation.Required(groups = {"ham"}) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java index d90cb4210139c..fb0a0d7e2e715 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ProxyInvocationHandlerTest.java @@ -1031,4 +1031,11 @@ public void testOptionsAreNotSerializable() { expectedException.expectCause(Matchers.instanceOf(NotSerializableException.class)); SerializableUtils.clone(new CapturesOptions()); } + + @Test + public void testGetOptionNameFromMethod() throws NoSuchMethodException { + ProxyInvocationHandler handler = new ProxyInvocationHandler(Maps.newHashMap()); + handler.as(BaseOptions.class); + assertEquals("foo", handler.getOptionName(BaseOptions.class.getMethod("getFoo"))); + } } From d33d64644c8ddc74a3dc07bf0c642a9ed818a624 Mon Sep 17 00:00:00 2001 From: Nigel Kilmer Date: Wed, 21 Jun 2017 11:26:10 -0700 Subject: [PATCH 117/200] Removed uses of proto builder clone method --- .../beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java | 9 ++++----- .../beam/sdk/io/gcp/datastore/DatastoreV1Test.java | 4 ++-- .../org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java index d1a17feab6e53..07476e2f3e642 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java @@ -168,14 +168,13 @@ private static class 
BigtableWriterImpl implements Writer { private BigtableSession session; private AsyncExecutor executor; private BulkMutation bulkMutation; - private final MutateRowRequest.Builder partialBuilder; + private final String tableName; public BigtableWriterImpl(BigtableSession session, BigtableTableName tableName) { this.session = session; executor = session.createAsyncExecutor(); bulkMutation = session.createBulkMutation(tableName, executor); - - partialBuilder = MutateRowRequest.newBuilder().setTableName(tableName.toString()); + this.tableName = tableName.toString(); } @Override @@ -208,8 +207,8 @@ public ListenableFuture writeRecord( KV> record) throws IOException { MutateRowRequest r = - partialBuilder - .clone() + MutateRowRequest.newBuilder() + .setTableName(tableName) .setRowKey(record.getKey()) .addAllMutations(record.getValue()) .build(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java index 946887c865e3c..a3f5d38ae886b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java @@ -783,7 +783,7 @@ public void testSplitQueryFnWithoutNumSplits() throws Exception { */ @Test public void testSplitQueryFnWithQueryLimit() throws Exception { - Query queryWithLimit = QUERY.toBuilder().clone() + Query queryWithLimit = QUERY.toBuilder() .setLimit(Int32Value.newBuilder().setValue(1)) .build(); @@ -1079,7 +1079,7 @@ private static Query makeLatestTimestampQuery(String namespace) { private List splitQuery(Query query, int numSplits) { List queries = new LinkedList<>(); for (int i = 0; i < numSplits; i++) { - queries.add(query.toBuilder().clone().build()); + queries.add(query.toBuilder().build()); } return queries; } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java index 5e618dfaabd2e..cd6122956a169 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/V1TestUtil.java @@ -374,7 +374,7 @@ boolean advance() throws IOException { // Read the next batch of query results. 
private Iterator getIteratorAndMoveCursor() throws DatastoreException { - Query.Builder query = this.query.toBuilder().clone(); + Query.Builder query = this.query.toBuilder(); query.setLimit(Int32Value.newBuilder().setValue(QUERY_BATCH_LIMIT)); if (currentBatch != null && !currentBatch.getEndCursor().isEmpty()) { query.setStartCursor(currentBatch.getEndCursor()); From 8dd0077d2a58e278b11c7e7eb4b5f182e1400992 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Mon, 26 Jun 2017 18:47:39 -0700 Subject: [PATCH 118/200] Enable grpc controller in fn_api_runner --- .../runners/portability/fn_api_runner.py | 12 +++++++--- .../runners/portability/fn_api_runner_test.py | 23 ++++++++++++++++++- .../apache_beam/runners/worker/sdk_worker.py | 2 +- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index a8e2eb4573a1f..c5438adbdcf0b 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -174,12 +174,17 @@ def get_outputs(op_ix): return {tag: pcollection_id(op_ix, out_ix) for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))} + def only_element(iterable): + element, = iterable + return element + for op_ix, (stage_name, operation) in enumerate(map_task): transform_id = uniquify(stage_name) if isinstance(operation, operation_specs.WorkerInMemoryWrite): # Write this data back to the runner. - runner_sinks[(transform_id, 'out')] = operation + target_name = only_element(get_inputs(operation).keys()) + runner_sinks[(transform_id, target_name)] = operation transform_spec = beam_runner_api_pb2.FunctionSpec( urn=sdk_worker.DATA_OUTPUT_URN, parameter=proto_utils.pack_Any(data_operation_spec)) @@ -190,7 +195,8 @@ def get_outputs(op_ix): maptask_executor_runner.InMemorySource) and isinstance(operation.source.source.default_output_coder(), WindowedValueCoder)): - input_data[(transform_id, 'input')] = self._reencode_elements( + target_name = only_element(get_outputs(op_ix).keys()) + input_data[(transform_id, target_name)] = self._reencode_elements( operation.source.source.read(None), operation.source.source.default_output_coder()) transform_spec = beam_runner_api_pb2.FunctionSpec( @@ -309,7 +315,7 @@ def _run_map_task( sink_op.output_buffer.append(e) return - def execute_map_tasks(self, ordered_map_tasks, direct=True): + def execute_map_tasks(self, ordered_map_tasks, direct=False): if direct: controller = FnApiRunner.DirectController() else: diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 91590351e99ef..163e98029467c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -21,6 +21,8 @@ import apache_beam as beam from apache_beam.runners.portability import fn_api_runner from apache_beam.runners.portability import maptask_executor_runner_test +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to class FnApiRunnerTest( @@ -31,9 +33,28 @@ def create_pipeline(self): runner=fn_api_runner.FnApiRunner()) def test_combine_per_key(self): - # TODO(robertwb): Implement PGBKCV operation. + # TODO(BEAM-1348): Enable once Partial GBK is supported in fn API. 
pass + def test_combine_per_key(self): + # TODO(BEAM-1348): Enable once Partial GBK is supported in fn API. + pass + + def test_pardo_side_inputs(self): + # TODO(BEAM-1348): Enable once side inputs are supported in fn API. + pass + + def test_pardo_unfusable_side_inputs(self): + # TODO(BEAM-1348): Enable once side inputs are supported in fn API. + pass + + def test_assert_that(self): + # TODO: figure out a way for fn_api_runner to parse and raise the + # underlying exception. + with self.assertRaisesRegexp(RuntimeError, 'BeamAssertException'): + with self.create_pipeline() as p: + assert_that(p | beam.Create(['a', 'b']), equal_to(['a'])) + # Inherits all tests from maptask_executor_runner.MapTaskExecutorRunner diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 6a366ebcb2be1..e1ddfb7807aa8 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -415,7 +415,7 @@ def create(factory, transform_id, transform_proto, grpc_port, consumers): def create(factory, transform_id, transform_proto, grpc_port, consumers): target = beam_fn_api_pb2.Target( primitive_transform_reference=transform_id, - name='out') + name=only_element(transform_proto.inputs.keys())) return DataOutputOperation( transform_proto.unique_name, transform_proto.unique_name, From 0c34db9b680c5cddf9874a7ccc86008c109068e0 Mon Sep 17 00:00:00 2001 From: Stepan Kadlec Date: Tue, 27 Jun 2017 09:12:47 -0700 Subject: [PATCH 119/200] [BEAM-2522] upgrading jackson to 2.8.9 (mitigating #1599) --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 98cace95808ab..29bb4ebed18e9 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 1.2.0 0.1.9 1.3 - 2.8.8 + 2.8.9 3.0.1 2.4 4.12 From d6855ac6797b8f83bf57b6ccdaf20bbf3db316c6 Mon Sep 17 00:00:00 2001 From: Ahmet Altay Date: Mon, 26 Jun 2017 23:22:36 -0700 Subject: [PATCH 120/200] Use installed distribution name for sdk name --- .../apache_beam/runners/dataflow/internal/dependency.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py index 6d4a703bba1b6..03e17940e5943 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py @@ -500,11 +500,13 @@ def get_sdk_name_and_version(): """For internal use only; no backwards-compatibility guarantees. Returns name and version of SDK reported to Google Cloud Dataflow.""" - # TODO(ccy): Make this check cleaner. 
+ import pkg_resources as pkg container_version = get_required_container_version() - if container_version == BEAM_CONTAINER_VERSION: + try: + pkg.get_distribution(GOOGLE_PACKAGE_NAME) + return ('Google Cloud Dataflow SDK for Python', container_version) + except pkg.DistributionNotFound: return ('Apache Beam SDK for Python', beam_version.__version__) - return ('Google Cloud Dataflow SDK for Python', container_version) def get_sdk_package_name(): From 7fee4b93d5b548d390ab2511a91880b4c5e57a26 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 27 Jun 2017 14:23:22 -0700 Subject: [PATCH 121/200] Add WindowFn#assignsToOneWindow --- .../beam/sdk/testing/StaticWindows.java | 5 ++++ .../transforms/windowing/GlobalWindows.java | 5 ++++ .../windowing/PartitioningWindowFn.java | 5 ++++ .../transforms/windowing/SlidingWindows.java | 5 ++++ .../sdk/transforms/windowing/WindowFn.java | 11 +++++++ .../beam/sdk/util/IdentityWindowFn.java | 5 ++++ .../windowing/SlidingWindowsTest.java | 30 +++++++++++++++---- 7 files changed, 61 insertions(+), 5 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/StaticWindows.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/StaticWindows.java index c11057a5c6236..eba6978744c18 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/StaticWindows.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/StaticWindows.java @@ -126,4 +126,9 @@ public BoundedWindow getSideInputWindow(BoundedWindow mainWindow) { } }; } + + @Override + public boolean assignsToOneWindow() { + return true; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/GlobalWindows.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/GlobalWindows.java index d48d26b1807c0..c68c497deb09a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/GlobalWindows.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/GlobalWindows.java @@ -78,6 +78,11 @@ public Instant getOutputTime(Instant inputTimestamp, GlobalWindow window) { return inputTimestamp; } + @Override + public boolean assignsToOneWindow() { + return true; + } + @Override public boolean equals(Object other) { return other instanceof GlobalWindows; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/PartitioningWindowFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/PartitioningWindowFn.java index 40ee68aae7f4f..341ba27ec8cda 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/PartitioningWindowFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/PartitioningWindowFn.java @@ -58,4 +58,9 @@ public W getSideInputWindow(BoundedWindow mainWindow) { public Instant getOutputTime(Instant inputTimestamp, W window) { return inputTimestamp; } + + @Override + public final boolean assignsToOneWindow() { + return true; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/SlidingWindows.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/SlidingWindows.java index f65788438021d..150b95633efa9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/SlidingWindows.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/SlidingWindows.java @@ -147,6 +147,11 @@ public boolean isCompatible(WindowFn other) { return equals(other); } + @Override + public boolean 
assignsToOneWindow() { + return !this.period.isShorterThan(this.size); + } + @Override public void verifyCompatibility(WindowFn other) throws IncompatibleWindowException { if (!this.isCompatible(other)) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/WindowFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/WindowFn.java index 001d63014ec9e..ffe85f3cf6c9b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/WindowFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/WindowFn.java @@ -179,6 +179,17 @@ public boolean isNonMerging() { return false; } + /** + * Returns true if this {@link WindowFn} always assigns an element to exactly one window. + * + *
    If this varies per-element, or cannot be determined, conservatively return false. + * + *
    By default, returns false. + */ + public boolean assignsToOneWindow() { + return false; + } + /** * Returns a {@link TypeDescriptor} capturing what is known statically about the window type of * this {@link WindowFn} instance's most-derived class. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/IdentityWindowFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/IdentityWindowFn.java index a4bfdda4f8134..ef6d83397f45a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/IdentityWindowFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/IdentityWindowFn.java @@ -116,4 +116,9 @@ public WindowMappingFn getDefaultWindowMappingFn() { public Instant getOutputTime(Instant inputTimestamp, BoundedWindow window) { return inputTimestamp; } + + @Override + public boolean assignsToOneWindow() { + return true; + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/SlidingWindowsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/SlidingWindowsTest.java index b14e2215173da..bfd01f02ed26d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/SlidingWindowsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/SlidingWindowsTest.java @@ -21,6 +21,7 @@ import static org.apache.beam.sdk.testing.WindowFnTestUtils.set; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -55,11 +56,12 @@ public void testSimple() throws Exception { expected.put(new IntervalWindow(new Instant(0), new Instant(10)), set(1, 2, 5, 9)); expected.put(new IntervalWindow(new Instant(5), new Instant(15)), set(5, 9, 10, 11)); expected.put(new IntervalWindow(new Instant(10), new Instant(20)), set(10, 11)); + SlidingWindows windowFn = SlidingWindows.of(new Duration(10)).every(new Duration(5)); assertEquals( expected, - runWindowFn( - SlidingWindows.of(new Duration(10)).every(new Duration(5)), + runWindowFn(windowFn, Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + assertThat(windowFn.assignsToOneWindow(), is(false)); } @Test @@ -69,11 +71,27 @@ public void testSlightlyOverlapping() throws Exception { expected.put(new IntervalWindow(new Instant(0), new Instant(7)), set(1, 2, 5)); expected.put(new IntervalWindow(new Instant(5), new Instant(12)), set(5, 9, 10, 11)); expected.put(new IntervalWindow(new Instant(10), new Instant(17)), set(10, 11)); + SlidingWindows windowFn = SlidingWindows.of(new Duration(7)).every(new Duration(5)); assertEquals( expected, - runWindowFn( - SlidingWindows.of(new Duration(7)).every(new Duration(5)), + runWindowFn(windowFn, Arrays.asList(1L, 2L, 5L, 9L, 10L, 11L))); + assertThat(windowFn.assignsToOneWindow(), is(false)); + } + + @Test + public void testEqualSize() throws Exception { + Map> expected = new HashMap<>(); + expected.put(new IntervalWindow(new Instant(0), new Instant(3)), set(1, 2)); + expected.put(new IntervalWindow(new Instant(3), new Instant(6)), set(3, 4, 5)); + expected.put(new IntervalWindow(new Instant(6), new Instant(9)), set(6, 7)); + SlidingWindows windowFn = SlidingWindows.of(new Duration(3)).every(new Duration(3)); + assertEquals( + expected, + runWindowFn( + windowFn, + Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L))); + assertThat(windowFn.assignsToOneWindow(), 
is(true)); } @Test @@ -82,12 +100,14 @@ public void testElidings() throws Exception { expected.put(new IntervalWindow(new Instant(0), new Instant(3)), set(1, 2)); expected.put(new IntervalWindow(new Instant(10), new Instant(13)), set(10, 11)); expected.put(new IntervalWindow(new Instant(100), new Instant(103)), set(100)); + SlidingWindows windowFn = SlidingWindows.of(new Duration(3)).every(new Duration(10)); assertEquals( expected, runWindowFn( // Only look at the first 3 millisecs of every 10-millisec interval. - SlidingWindows.of(new Duration(3)).every(new Duration(10)), + windowFn, Arrays.asList(1L, 2L, 3L, 5L, 9L, 10L, 11L, 100L))); + assertThat(windowFn.assignsToOneWindow(), is(true)); } @Test From 03741ba0b55e446cd2583e7ccc88f79ec21705b5 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Fri, 23 Jun 2017 09:32:49 -0700 Subject: [PATCH 122/200] [BEAM-1187] Improve logging to contain the number of retries done due to IOException and unsuccessful response codes. --- .../sdk/util/RetryHttpRequestInitializer.java | 148 +++++++++++------- .../util/RetryHttpRequestInitializerTest.java | 31 ++-- 2 files changed, 116 insertions(+), 63 deletions(-) diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java index a23bee387e2bd..fd908cf780547 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/util/RetryHttpRequestInitializer.java @@ -17,8 +17,9 @@ */ package org.apache.beam.sdk.util; -import com.google.api.client.http.HttpBackOffIOExceptionHandler; -import com.google.api.client.http.HttpBackOffUnsuccessfulResponseHandler; +import static com.google.api.client.util.BackOffUtils.next; + +import com.google.api.client.http.HttpIOExceptionHandler; import com.google.api.client.http.HttpRequest; import com.google.api.client.http.HttpRequestInitializer; import com.google.api.client.http.HttpResponse; @@ -60,65 +61,106 @@ public class RetryHttpRequestInitializer implements HttpRequestInitializer { */ private static final int HANGING_GET_TIMEOUT_SEC = 80; - private static class LoggingHttpBackOffIOExceptionHandler - extends HttpBackOffIOExceptionHandler { - public LoggingHttpBackOffIOExceptionHandler(BackOff backOff) { - super(backOff); + /** Handlers used to provide additional logging information on unsuccessful HTTP requests. 
*/ + private static class LoggingHttpBackOffHandler + implements HttpIOExceptionHandler, HttpUnsuccessfulResponseHandler { + + private final Sleeper sleeper; + private final BackOff ioExceptionBackOff; + private final BackOff unsuccessfulResponseBackOff; + private final Set ignoredResponseCodes; + private int ioExceptionRetries; + private int unsuccessfulResponseRetries; + + private LoggingHttpBackOffHandler( + Sleeper sleeper, + BackOff ioExceptionBackOff, + BackOff unsucessfulResponseBackOff, + Set ignoredResponseCodes) { + this.sleeper = sleeper; + this.ioExceptionBackOff = ioExceptionBackOff; + this.unsuccessfulResponseBackOff = unsucessfulResponseBackOff; + this.ignoredResponseCodes = ignoredResponseCodes; } @Override public boolean handleIOException(HttpRequest request, boolean supportsRetry) throws IOException { - boolean willRetry = super.handleIOException(request, supportsRetry); + // We will retry if the request supports retry or the backoff was successful. + // Note that the order of these checks is important since + // backOffWasSuccessful will perform a sleep. + boolean willRetry = supportsRetry && backOffWasSuccessful(ioExceptionBackOff); if (willRetry) { + ioExceptionRetries += 1; LOG.debug("Request failed with IOException, will retry: {}", request.getUrl()); } else { - LOG.warn( - "Request failed with IOException (caller responsible for retrying): {}", + String message = "Request failed with IOException, " + + "performed {} retries due to IOExceptions, " + + "performed {} retries due to unsuccessful status codes, " + + "HTTP framework says request {} be retried, " + + "(caller responsible for retrying): {}"; + LOG.warn(message, + ioExceptionRetries, + unsuccessfulResponseRetries, + supportsRetry ? "can" : "cannot", request.getUrl()); } return willRetry; } - } - - private static class LoggingHttpBackoffUnsuccessfulResponseHandler - implements HttpUnsuccessfulResponseHandler { - private final HttpBackOffUnsuccessfulResponseHandler handler; - private final Set ignoredResponseCodes; - - public LoggingHttpBackoffUnsuccessfulResponseHandler(BackOff backoff, - Sleeper sleeper, Set ignoredResponseCodes) { - this.ignoredResponseCodes = ignoredResponseCodes; - handler = new HttpBackOffUnsuccessfulResponseHandler(backoff); - handler.setSleeper(sleeper); - handler.setBackOffRequired( - new HttpBackOffUnsuccessfulResponseHandler.BackOffRequired() { - @Override - public boolean isRequired(HttpResponse response) { - int statusCode = response.getStatusCode(); - return (statusCode == 0) // Code 0 usually means no response / network error - || (statusCode / 100 == 5) // 5xx: server error - || statusCode == 429; // 429: Too many requests - } - }); - } @Override - public boolean handleResponse(HttpRequest request, HttpResponse response, - boolean supportsRetry) throws IOException { - boolean retry = handler.handleResponse(request, response, supportsRetry); - if (retry) { + public boolean handleResponse(HttpRequest request, HttpResponse response, boolean supportsRetry) + throws IOException { + // We will retry if the request supports retry and the status code requires a backoff + // and the backoff was successful. Note that the order of these checks is important since + // backOffWasSuccessful will perform a sleep. 
+ boolean willRetry = supportsRetry + && retryOnStatusCode(response.getStatusCode()) + && backOffWasSuccessful(unsuccessfulResponseBackOff); + if (willRetry) { + unsuccessfulResponseRetries += 1; LOG.debug("Request failed with code {}, will retry: {}", response.getStatusCode(), request.getUrl()); + } else { + String message = "Request failed with code {}, " + + "performed {} retries due to IOExceptions, " + + "performed {} retries due to unsuccessful status codes, " + + "HTTP framework says request {} be retried, " + + "(caller responsible for retrying): {}"; + if (ignoredResponseCodes.contains(response.getStatusCode())) { + // Log ignored response codes at a lower level + LOG.debug(message, + response.getStatusCode(), + ioExceptionRetries, + unsuccessfulResponseRetries, + supportsRetry ? "can" : "cannot", + request.getUrl()); + } else { + LOG.warn(message, + response.getStatusCode(), + ioExceptionRetries, + unsuccessfulResponseRetries, + supportsRetry ? "can" : "cannot", + request.getUrl()); + } + } + return willRetry; + } - } else if (!ignoredResponseCodes.contains(response.getStatusCode())) { - LOG.warn( - "Request failed with code {} (caller responsible for retrying): {}", - response.getStatusCode(), - request.getUrl()); + /** Returns true iff performing the backoff was successful. */ + private boolean backOffWasSuccessful(BackOff backOff) { + try { + return next(sleeper, backOff); + } catch (InterruptedException | IOException e) { + return false; } + } - return retry; + /** Returns true iff the {@code statusCode} represents an error that should be retried. */ + private boolean retryOnStatusCode(int statusCode) { + return (statusCode == 0) // Code 0 usually means no response / network error + || (statusCode / 100 == 5) // 5xx: server error + || statusCode == 429; // 429: Too many requests } } @@ -174,20 +216,20 @@ public void initialize(HttpRequest request) throws IOException { // TODO: Do this exclusively for work requests. request.setReadTimeout(HANGING_GET_TIMEOUT_SEC * 1000); - // Back off on retryable http errors. - request.setUnsuccessfulResponseHandler( + LoggingHttpBackOffHandler loggingHttpBackOffHandler = new LoggingHttpBackOffHandler( + sleeper, + // Retry immediately on IOExceptions. + BackOff.ZERO_BACKOFF, + // Back off on retryable http errors. // A back-off multiplier of 2 raises the maximum request retrying time // to approximately 5 minutes (keeping other back-off parameters to // their default values). - new LoggingHttpBackoffUnsuccessfulResponseHandler( - new ExponentialBackOff.Builder().setNanoClock(nanoClock) - .setMultiplier(2).build(), - sleeper, ignoredResponseCodes)); - - // Retry immediately on IOExceptions. 
- LoggingHttpBackOffIOExceptionHandler loggingBackoffHandler = - new LoggingHttpBackOffIOExceptionHandler(BackOff.ZERO_BACKOFF); - request.setIOExceptionHandler(loggingBackoffHandler); + new ExponentialBackOff.Builder().setNanoClock(nanoClock).setMultiplier(2).build(), + ignoredResponseCodes + ); + + request.setUnsuccessfulResponseHandler(loggingHttpBackOffHandler); + request.setIOExceptionHandler(loggingHttpBackOffHandler); // Set response initializer if (responseInterceptor != null) { diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/RetryHttpRequestInitializerTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/RetryHttpRequestInitializerTest.java index 37551a4f9f0ba..13a9309038e84 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/RetryHttpRequestInitializerTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/util/RetryHttpRequestInitializerTest.java @@ -49,10 +49,11 @@ import java.security.PrivateKey; import java.util.Arrays; import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.sdk.testing.ExpectedLogs; import org.hamcrest.Matchers; import org.junit.After; -import org.junit.Assert; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -67,6 +68,8 @@ @RunWith(JUnit4.class) public class RetryHttpRequestInitializerTest { + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(RetryHttpRequestInitializer.class); + @Mock private PrivateKey mockPrivateKey; @Mock private LowLevelHttpRequest mockLowLevelRequest; @Mock private LowLevelHttpResponse mockLowLevelResponse; @@ -135,6 +138,7 @@ public void testBasicOperation() throws IOException { verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); verify(mockLowLevelRequest).execute(); verify(mockLowLevelResponse).getStatusCode(); + expectedLogs.verifyNotLogged("Request failed"); } /** @@ -153,7 +157,7 @@ public void testErrorCodeForbidden() throws IOException { HttpResponse response = result.executeUnparsed(); assertNotNull(response); } catch (HttpResponseException e) { - Assert.assertThat(e.getMessage(), Matchers.containsString("403")); + assertThat(e.getMessage(), Matchers.containsString("403")); } verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); @@ -162,6 +166,7 @@ public void testErrorCodeForbidden() throws IOException { verify(mockLowLevelRequest).setTimeout(anyInt(), anyInt()); verify(mockLowLevelRequest).execute(); verify(mockLowLevelResponse).getStatusCode(); + expectedLogs.verifyWarn("Request failed with code 403"); } /** @@ -188,6 +193,7 @@ public void testRetryableError() throws IOException { verify(mockLowLevelRequest, times(3)).setTimeout(anyInt(), anyInt()); verify(mockLowLevelRequest, times(3)).execute(); verify(mockLowLevelResponse, times(3)).getStatusCode(); + expectedLogs.verifyDebug("Request failed with code 503"); } /** @@ -211,6 +217,7 @@ public void testThrowIOException() throws IOException { verify(mockLowLevelRequest, times(2)).setTimeout(anyInt(), anyInt()); verify(mockLowLevelRequest, times(2)).execute(); verify(mockLowLevelResponse).getStatusCode(); + expectedLogs.verifyDebug("Request failed with IOException"); } /** @@ -224,19 +231,22 @@ public void testRetryableErrorRetryEnoughTimes() throws IOException { int n = 0; @Override public Integer answer(InvocationOnMock invocation) { - return 
(n++ < retries - 1) ? 503 : 200; + return n++ < retries ? 503 : 9999; }}); Storage.Buckets.Get result = storage.buckets().get("test"); - HttpResponse response = result.executeUnparsed(); - assertNotNull(response); + try { + result.executeUnparsed(); + fail(); + } catch (Throwable t) { + } verify(mockHttpResponseInterceptor).interceptResponse(any(HttpResponse.class)); - verify(mockLowLevelRequest, atLeastOnce()).addHeader(anyString(), - anyString()); - verify(mockLowLevelRequest, times(retries)).setTimeout(anyInt(), anyInt()); - verify(mockLowLevelRequest, times(retries)).execute(); - verify(mockLowLevelResponse, times(retries)).getStatusCode(); + verify(mockLowLevelRequest, atLeastOnce()).addHeader(anyString(), anyString()); + verify(mockLowLevelRequest, times(retries + 1)).setTimeout(anyInt(), anyInt()); + verify(mockLowLevelRequest, times(retries + 1)).execute(); + verify(mockLowLevelResponse, times(retries + 1)).getStatusCode(); + expectedLogs.verifyWarn("performed 10 retries due to unsuccessful status codes"); } /** @@ -276,6 +286,7 @@ public LowLevelHttpResponse execute() throws IOException { } catch (Throwable e) { assertThat(e, Matchers.instanceOf(SocketTimeoutException.class)); assertEquals(1 + defaultNumberOfRetries, executeCount.get()); + expectedLogs.verifyWarn("performed 10 retries due to IOExceptions"); } } } From 80c9263617eb453a3595735147f328a8ee6d783e Mon Sep 17 00:00:00 2001 From: Mairbek Khadikov Date: Mon, 19 Jun 2017 12:23:52 -0700 Subject: [PATCH 123/200] Bump spanner version --- pom.xml | 2 +- .../java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java | 2 +- .../java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java | 3 --- .../org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pom.xml b/pom.xml index 29bb4ebed18e9..f06568b77860c 100644 --- a/pom.xml +++ b/pom.xml @@ -138,7 +138,7 @@ 3.2.0 v1-rev10-1.22.0 1.7.14 - 0.16.0-beta + 0.20.0-beta 1.6.2 4.3.5.RELEASE 3.1.4 diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 8bfc247adda8c..32bf1d0f0b3f9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -337,7 +337,7 @@ public void teardown() throws Exception { if (spanner == null) { return; } - spanner.closeAsync().get(); + spanner.close(); spanner = null; } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java index 1e19a59c4849f..0cc08bfc03184 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java @@ -27,7 +27,6 @@ import static org.mockito.Mockito.when; import static org.mockito.Mockito.withSettings; -import com.google.api.core.ApiFuture; import com.google.cloud.ServiceFactory; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; @@ -274,10 +273,8 @@ public FakeServiceFactory() { mockSpanners.add(mock(Spanner.class, withSettings().serializable())); mockDatabaseClients.add(mock(DatabaseClient.class, 
withSettings().serializable())); } - ApiFuture voidFuture = mock(ApiFuture.class, withSettings().serializable()); when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class))) .thenReturn(mockDatabaseClient()); - when(mockSpanner().closeAsync()).thenReturn(voidFuture); } DatabaseClient mockDatabaseClient() { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index e1f6582749a4f..33532c929bab0 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -150,7 +150,7 @@ public void testWrite() throws Exception { @After public void tearDown() throws Exception { databaseAdminClient.dropDatabase(options.getInstanceId(), databaseName); - spanner.closeAsync().get(); + spanner.close(); } private static class GenerateMutations extends DoFn { From 454f1c427353feeb858cdc62185ea3fced8d8a1f Mon Sep 17 00:00:00 2001 From: Mairbek Khadikov Date: Mon, 19 Jun 2017 13:01:20 -0700 Subject: [PATCH 124/200] Pre read api refactoring. Extract `SpannerConfig` and `AbstractSpannerFn` --- .../sdk/io/gcp/spanner/AbstractSpannerFn.java | 41 ++++ .../sdk/io/gcp/spanner/SpannerConfig.java | 118 +++++++++ .../beam/sdk/io/gcp/spanner/SpannerIO.java | 227 ++++-------------- .../io/gcp/spanner/SpannerWriteGroupFn.java | 108 +++++++++ .../sdk/io/gcp/spanner/SpannerIOTest.java | 8 +- 5 files changed, 321 insertions(+), 181 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java new file mode 100644 index 0000000000000..08f7fa9cb60f7 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java @@ -0,0 +1,41 @@ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import org.apache.beam.sdk.transforms.DoFn; + +/** + * Abstract {@link DoFn} that manages {@link Spanner} lifecycle. Use {@link + * AbstractSpannerFn#databaseClient} to access the Cloud Spanner database client. 
+ */ +abstract class AbstractSpannerFn extends DoFn { + private transient Spanner spanner; + private transient DatabaseClient databaseClient; + + abstract SpannerConfig getSpannerConfig(); + + @Setup + public void setup() throws Exception { + SpannerConfig spannerConfig = getSpannerConfig(); + SpannerOptions options = spannerConfig.buildSpannerOptions(); + spanner = options.getService(); + databaseClient = spanner.getDatabaseClient(DatabaseId + .of(options.getProjectId(), spannerConfig.getInstanceId().get(), + spannerConfig.getDatabaseId().get())); + } + + @Teardown + public void teardown() throws Exception { + if (spanner == null) { + return; + } + spanner.close(); + spanner = null; + } + + protected DatabaseClient databaseClient() { + return databaseClient; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java new file mode 100644 index 0000000000000..4cb8aa28bd637 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java @@ -0,0 +1,118 @@ +package org.apache.beam.sdk.io.gcp.spanner; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import com.google.cloud.ServiceFactory; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.common.annotations.VisibleForTesting; +import java.io.Serializable; +import javax.annotation.Nullable; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.display.DisplayData; + +/** Configuration for a Cloud Spanner client. 
*/ +@AutoValue +public abstract class SpannerConfig implements Serializable { + + private static final long serialVersionUID = -5680874609304170301L; + + @Nullable + abstract ValueProvider getProjectId(); + + @Nullable + abstract ValueProvider getInstanceId(); + + @Nullable + abstract ValueProvider getDatabaseId(); + + @Nullable + @VisibleForTesting + abstract ServiceFactory getServiceFactory(); + + abstract Builder toBuilder(); + + SpannerOptions buildSpannerOptions() { + SpannerOptions.Builder builder = SpannerOptions.newBuilder(); + if (getProjectId() != null) { + builder.setProjectId(getProjectId().get()); + } + if (getServiceFactory() != null) { + builder.setServiceFactory(getServiceFactory()); + } + return builder.build(); + } + + public static SpannerConfig create() { + return builder().build(); + } + + public static Builder builder() { + return new AutoValue_SpannerConfig.Builder(); + } + + public void validate(PipelineOptions options) { + checkNotNull( + getInstanceId(), + "SpannerIO.read() requires instance id to be set with withInstanceId method"); + checkNotNull( + getDatabaseId(), + "SpannerIO.read() requires database id to be set with withDatabaseId method"); + } + + public void populateDisplayData(DisplayData.Builder builder) { + builder + .addIfNotNull(DisplayData.item("projectId", getProjectId()).withLabel("Output Project")) + .addIfNotNull(DisplayData.item("instanceId", getInstanceId()).withLabel("Output Instance")) + .addIfNotNull(DisplayData.item("databaseId", getDatabaseId()).withLabel("Output Database")); + + if (getServiceFactory() != null) { + builder.addIfNotNull( + DisplayData.item("serviceFactory", getServiceFactory().getClass().getName()) + .withLabel("Service Factory")); + } + } + + /** Builder for {@link SpannerConfig}. 
*/ + @AutoValue.Builder + public abstract static class Builder { + + + abstract Builder setProjectId(ValueProvider projectId); + + abstract Builder setInstanceId(ValueProvider instanceId); + + abstract Builder setDatabaseId(ValueProvider databaseId); + + + abstract Builder setServiceFactory(ServiceFactory serviceFactory); + + public abstract SpannerConfig build(); + } + + public SpannerConfig withProjectId(ValueProvider projectId) { + return toBuilder().setProjectId(projectId).build(); + } + + public SpannerConfig withProjectId(String projectId) { + return withProjectId(ValueProvider.StaticValueProvider.of(projectId)); + } + + public SpannerConfig withInstanceId(ValueProvider instanceId) { + return toBuilder().setInstanceId(instanceId).build(); + } + + public SpannerConfig withInstanceId(String instanceId) { + return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId)); + } + + public SpannerConfig withDatabaseId(ValueProvider databaseId) { + return toBuilder().setDatabaseId(databaseId).build(); + } + + public SpannerConfig withDatabaseId(String databaseId) { + return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 32bf1d0f0b3f9..791c7e71daf73 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -17,22 +17,13 @@ */ package org.apache.beam.sdk.io.gcp.spanner; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.auto.value.AutoValue; import com.google.cloud.ServiceFactory; -import com.google.cloud.ServiceOptions; -import com.google.cloud.spanner.AbortedException; -import com.google.cloud.spanner.DatabaseClient; -import com.google.cloud.spanner.DatabaseId; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Iterables; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; + import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; @@ -42,16 +33,8 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.DisplayData.Builder; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Experimental {@link PTransform Transforms} for reading from and writing to , PDone> { - @Nullable - abstract ValueProvider getProjectId(); + private static final long serialVersionUID = 1920175411827980145L; - @Nullable - abstract ValueProvider getInstanceId(); - - @Nullable - abstract ValueProvider getDatabaseId(); + abstract SpannerConfig getSpannerConfig(); abstract long getBatchSizeBytes(); - @Nullable - @VisibleForTesting - abstract ServiceFactory getServiceFactory(); - abstract Builder toBuilder(); 
@AutoValue.Builder abstract static class Builder { - abstract Builder setProjectId(ValueProvider projectId); - - abstract Builder setInstanceId(ValueProvider instanceId); + abstract Builder setSpannerConfig(SpannerConfig spannerConfig); - abstract Builder setDatabaseId(ValueProvider databaseId); + abstract SpannerConfig.Builder spannerConfigBuilder(); abstract Builder setBatchSizeBytes(long batchSizeBytes); - @VisibleForTesting - abstract Builder setServiceFactory(ServiceFactory serviceFactory); - abstract Write build(); } @@ -166,8 +135,15 @@ public Write withProjectId(String projectId) { return withProjectId(ValueProvider.StaticValueProvider.of(projectId)); } + /** + * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project. + * + *
    Does not modify this object. + */ public Write withProjectId(ValueProvider projectId) { - return toBuilder().setProjectId(projectId).build(); + Write.Builder builder = toBuilder(); + builder.spannerConfigBuilder().setProjectId(projectId); + return builder.build(); } /** @@ -180,10 +156,29 @@ public Write withInstanceId(String instanceId) { return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId)); } + /** + * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner + * instance. + * + *
    Does not modify this object. + */ public Write withInstanceId(ValueProvider instanceId) { - return toBuilder().setInstanceId(instanceId).build(); + Write.Builder builder = toBuilder(); + builder.spannerConfigBuilder().setInstanceId(instanceId); + return builder.build(); } + /** + * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner + * config. + * + *
    Does not modify this object. + */ + public Write withSpannerConfig(SpannerConfig spannerConfig) { + return toBuilder().setSpannerConfig(spannerConfig).build(); + } + + /** * Returns a new {@link SpannerIO.Write} with a new batch size limit. * @@ -203,8 +198,16 @@ public Write withDatabaseId(String databaseId) { return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); } + /** + * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner + * database. + * + *
    Does not modify this object. + */ public Write withDatabaseId(ValueProvider databaseId) { - return toBuilder().setDatabaseId(databaseId).build(); + Write.Builder builder = toBuilder(); + builder.spannerConfigBuilder().setDatabaseId(databaseId); + return builder.build(); } /** @@ -216,17 +219,14 @@ public WriteGrouped grouped() { @VisibleForTesting Write withServiceFactory(ServiceFactory serviceFactory) { - return toBuilder().setServiceFactory(serviceFactory).build(); + Write.Builder builder = toBuilder(); + builder.spannerConfigBuilder().setServiceFactory(serviceFactory); + return builder.build(); } @Override public void validate(PipelineOptions options) { - checkNotNull( - getInstanceId(), - "SpannerIO.write() requires instance id to be set with withInstanceId method"); - checkNotNull( - getDatabaseId(), - "SpannerIO.write() requires database id to be set with withDatabaseId method"); + getSpannerConfig().validate(options); } @Override @@ -237,22 +237,13 @@ public PDone expand(PCollection input) { return PDone.in(input.getPipeline()); } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder - .addIfNotNull(DisplayData.item("projectId", getProjectId()).withLabel("Output Project")) - .addIfNotNull( - DisplayData.item("instanceId", getInstanceId()).withLabel("Output Instance")) - .addIfNotNull( - DisplayData.item("databaseId", getDatabaseId()).withLabel("Output Database")) - .add(DisplayData.item("batchSizeBytes", getBatchSizeBytes()) - .withLabel("Batch Size in Bytes")); - if (getServiceFactory() != null) { - builder.addIfNotNull( - DisplayData.item("serviceFactory", getServiceFactory().getClass().getName()) - .withLabel("Service Factory")); - } + getSpannerConfig().populateDisplayData(builder); + builder.add( + DisplayData.item("batchSizeBytes", getBatchSizeBytes()).withLabel("Batch Size in Bytes")); } } @@ -278,123 +269,5 @@ public void processElement(ProcessContext c) throws Exception { } } - /** Batches together and writes mutations to Google Cloud Spanner. */ - @VisibleForTesting - static class SpannerWriteGroupFn extends DoFn { - private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteGroupFn.class); - private final Write spec; - private transient Spanner spanner; - private transient DatabaseClient dbClient; - // Current batch of mutations to be written. - private List mutations; - private long batchSizeBytes = 0; - - private static final int MAX_RETRIES = 5; - private static final FluentBackoff BUNDLE_WRITE_BACKOFF = - FluentBackoff.DEFAULT - .withMaxRetries(MAX_RETRIES) - .withInitialBackoff(Duration.standardSeconds(5)); - - @VisibleForTesting SpannerWriteGroupFn(Write spec) { - this.spec = spec; - } - - @Setup - public void setup() throws Exception { - SpannerOptions spannerOptions = getSpannerOptions(); - spanner = spannerOptions.getService(); - dbClient = spanner.getDatabaseClient( - DatabaseId.of(projectId(), spec.getInstanceId().get(), spec.getDatabaseId().get())); - mutations = new ArrayList<>(); - batchSizeBytes = 0; - } - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - MutationGroup m = c.element(); - mutations.add(m); - batchSizeBytes += MutationSizeEstimator.sizeOf(m); - if (batchSizeBytes >= spec.getBatchSizeBytes()) { - flushBatch(); - } - } - - private String projectId() { - return spec.getProjectId() == null - ? 
ServiceOptions.getDefaultProjectId() - : spec.getProjectId().get(); - } - - @FinishBundle - public void finishBundle() throws Exception { - if (!mutations.isEmpty()) { - flushBatch(); - } - } - - @Teardown - public void teardown() throws Exception { - if (spanner == null) { - return; - } - spanner.close(); - spanner = null; - } - - private SpannerOptions getSpannerOptions() { - SpannerOptions.Builder spannerOptionsBuider = SpannerOptions.newBuilder(); - if (spec.getServiceFactory() != null) { - spannerOptionsBuider.setServiceFactory(spec.getServiceFactory()); - } - if (spec.getProjectId() != null) { - spannerOptionsBuider.setProjectId(spec.getProjectId().get()); - } - return spannerOptionsBuider.build(); - } - - /** - * Writes a batch of mutations to Cloud Spanner. - * - *
    If a commit fails, it will be retried up to {@link #MAX_RETRIES} times. If the retry limit - * is exceeded, the last exception from Cloud Spanner will be thrown. - * - * @throws AbortedException if the commit fails or IOException or InterruptedException if - * backing off between retries fails. - */ - private void flushBatch() throws AbortedException, IOException, InterruptedException { - LOG.debug("Writing batch of {} mutations", mutations.size()); - Sleeper sleeper = Sleeper.DEFAULT; - BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff(); - - while (true) { - // Batch upsert rows. - try { - dbClient.writeAtLeastOnce(Iterables.concat(mutations)); - - // Break if the commit threw no exception. - break; - } catch (AbortedException exception) { - // Only log the code and message for potentially-transient errors. The entire exception - // will be propagated upon the last retry. - LOG.error( - "Error writing to Spanner ({}): {}", exception.getCode(), exception.getMessage()); - if (!BackOffUtils.next(sleeper, backoff)) { - LOG.error("Aborting after {} retries.", MAX_RETRIES); - throw exception; - } - } - } - LOG.debug("Successfully wrote {} mutations", mutations.size()); - mutations = new ArrayList<>(); - batchSizeBytes = 0; - } - - @Override - public void populateDisplayData(Builder builder) { - super.populateDisplayData(builder); - spec.populateDisplayData(builder); - } - } - private SpannerIO() {} // Prevent construction. } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java new file mode 100644 index 0000000000000..aed4832b7d863 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java @@ -0,0 +1,108 @@ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.AbortedException; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.BackOffUtils; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Batches together and writes mutations to Google Cloud Spanner. */ +@VisibleForTesting class SpannerWriteGroupFn extends AbstractSpannerFn { + private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteGroupFn.class); + private final SpannerIO.Write spec; + // Current batch of mutations to be written. 
+ private List mutations; + private long batchSizeBytes = 0; + + private static final int MAX_RETRIES = 5; + private static final FluentBackoff BUNDLE_WRITE_BACKOFF = + FluentBackoff.DEFAULT + .withMaxRetries(MAX_RETRIES) + .withInitialBackoff(Duration.standardSeconds(5)); + + @VisibleForTesting SpannerWriteGroupFn(SpannerIO.Write spec) { + this.spec = spec; + } + + @Override SpannerConfig getSpannerConfig() { + return spec.getSpannerConfig(); + } + + @Setup + public void setup() throws Exception { + super.setup(); + mutations = new ArrayList<>(); + batchSizeBytes = 0; + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + MutationGroup m = c.element(); + mutations.add(m); + batchSizeBytes += MutationSizeEstimator.sizeOf(m); + if (batchSizeBytes >= spec.getBatchSizeBytes()) { + flushBatch(); + } + } + + @FinishBundle + public void finishBundle() throws Exception { + if (!mutations.isEmpty()) { + flushBatch(); + } + } + + /** + * Writes a batch of mutations to Cloud Spanner. + * + *
    If a commit fails, it will be retried up to {@link #MAX_RETRIES} times. If the retry limit + * is exceeded, the last exception from Cloud Spanner will be thrown. + * + * @throws AbortedException if the commit fails or IOException or InterruptedException if + * backing off between retries fails. + */ + private void flushBatch() throws AbortedException, IOException, InterruptedException { + LOG.debug("Writing batch of {} mutations", mutations.size()); + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff(); + + while (true) { + // Batch upsert rows. + try { + databaseClient().writeAtLeastOnce(Iterables.concat(mutations)); + + // Break if the commit threw no exception. + break; + } catch (AbortedException exception) { + // Only log the code and message for potentially-transient errors. The entire exception + // will be propagated upon the last retry. + LOG.error( + "Error writing to Spanner ({}): {}", exception.getCode(), exception.getMessage()); + if (!BackOffUtils.next(sleeper, backoff)) { + LOG.error("Aborting after {} retries.", MAX_RETRIES); + throw exception; + } + } + } + LOG.debug("Successfully wrote {} mutations", mutations.size()); + mutations = new ArrayList<>(); + batchSizeBytes = 0; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + spec.populateDisplayData(builder); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java index 0cc08bfc03184..abeac0a8f4ae1 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java @@ -149,7 +149,7 @@ public void batching() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(1000000000) .withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write); DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two)); @@ -175,7 +175,7 @@ public void batchingGroups() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(batchSize) .withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write); DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two, three)); @@ -198,7 +198,7 @@ public void noBatching() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(0) // turn off batching. 
.withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write); DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(one, two)); @@ -224,7 +224,7 @@ public void groups() throws Exception { .withDatabaseId("test-database") .withBatchSizeBytes(batchSize) .withServiceFactory(serviceFactory); - SpannerIO.SpannerWriteGroupFn writerFn = new SpannerIO.SpannerWriteGroupFn(write); + SpannerWriteGroupFn writerFn = new SpannerWriteGroupFn(write); DoFnTester fnTester = DoFnTester.of(writerFn); fnTester.processBundle(Arrays.asList(g(one, two, three))); From a21a6d797777ad38b927bd5d44c63306d85a752a Mon Sep 17 00:00:00 2001 From: Mairbek Khadikov Date: Mon, 19 Jun 2017 13:28:52 -0700 Subject: [PATCH 125/200] Read api with naive implementation --- pom.xml | 12 + sdks/java/io/google-cloud-platform/pom.xml | 16 +- .../sdk/io/gcp/spanner/AbstractSpannerFn.java | 17 + .../io/gcp/spanner/CreateTransactionFn.java | 51 ++ .../io/gcp/spanner/NaiveSpannerReadFn.java | 65 +++ .../sdk/io/gcp/spanner/SpannerConfig.java | 29 +- .../beam/sdk/io/gcp/spanner/SpannerIO.java | 479 +++++++++++++++--- .../io/gcp/spanner/SpannerWriteGroupFn.java | 17 + .../beam/sdk/io/gcp/spanner/Transaction.java | 33 ++ .../beam/sdk/io/gcp/GcpApiSurfaceTest.java | 10 + .../io/gcp/spanner/FakeServiceFactory.java | 82 +++ .../sdk/io/gcp/spanner/SpannerIOReadTest.java | 275 ++++++++++ ...nerIOTest.java => SpannerIOWriteTest.java} | 58 +-- .../sdk/io/gcp/spanner/SpannerReadIT.java | 169 ++++++ 14 files changed, 1175 insertions(+), 138 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java rename sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/{SpannerIOTest.java => SpannerIOWriteTest.java} (85%) create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java diff --git a/pom.xml b/pom.xml index f06568b77860c..069191ccbb95c 100644 --- a/pom.xml +++ b/pom.xml @@ -161,6 +161,7 @@ -Werror -Xpkginfo:always nothing + 0.20.0 pom @@ -637,6 +638,12 @@ ${google-api-common.version} + + com.google.api + gax-grpc + ${gax-grpc.version} + + com.google.api-client google-api-client @@ -851,6 +858,11 @@ + + com.google.cloud + google-cloud-core-grpc + ${grpc.version} + com.google.cloud.bigtable bigtable-protos diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml index 6737eea5b256a..94066c7b038a5 100644 --- a/sdks/java/io/google-cloud-platform/pom.xml +++ b/sdks/java/io/google-cloud-platform/pom.xml @@ -93,7 +93,12 @@ com.google.api - api-common + gax-grpc + + + + com.google.cloud + google-cloud-core-grpc @@ -255,11 +260,16 @@ org.apache.commons - commons-text - test + commons-lang3 + provided + + org.apache.commons + commons-text + test + org.apache.beam beam-sdks-java-core diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java index 08f7fa9cb60f7..00008f1ebdcc1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.sdk.io.gcp.spanner; import com.google.cloud.spanner.DatabaseClient; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java new file mode 100644 index 0000000000000..da8e8b15e1adf --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/CreateTransactionFn.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.ReadOnlyTransaction; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.Statement; + +/** Creates a batch transaction. */ +class CreateTransactionFn extends AbstractSpannerFn { + + private final SpannerIO.CreateTransaction config; + + CreateTransactionFn(SpannerIO.CreateTransaction config) { + this.config = config; + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + try (ReadOnlyTransaction readOnlyTransaction = + databaseClient().readOnlyTransaction(config.getTimestampBound())) { + // Run a dummy sql statement to force the RPC and obtain the timestamp from the server. 
+ ResultSet resultSet = readOnlyTransaction.executeQuery(Statement.of("SELECT 1")); + while (resultSet.next()) { + // do nothing + } + Transaction tx = Transaction.create(readOnlyTransaction.getReadTimestamp()); + c.output(tx); + } + } + + @Override + SpannerConfig getSpannerConfig() { + return config.getSpannerConfig(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java new file mode 100644 index 0000000000000..d193b95768ad2 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/NaiveSpannerReadFn.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.ReadOnlyTransaction; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TimestampBound; +import com.google.common.annotations.VisibleForTesting; + +/** A simplest read function implementation. Parallelism support is coming. 
*/ +@VisibleForTesting +class NaiveSpannerReadFn extends AbstractSpannerFn { + private final SpannerIO.Read config; + + NaiveSpannerReadFn(SpannerIO.Read config) { + this.config = config; + } + + SpannerConfig getSpannerConfig() { + return config.getSpannerConfig(); + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + TimestampBound timestampBound = TimestampBound.strong(); + if (config.getTransaction() != null) { + Transaction transaction = c.sideInput(config.getTransaction()); + timestampBound = TimestampBound.ofReadTimestamp(transaction.timestamp()); + } + try (ReadOnlyTransaction readOnlyTransaction = + databaseClient().readOnlyTransaction(timestampBound)) { + ResultSet resultSet = execute(readOnlyTransaction); + while (resultSet.next()) { + c.output(resultSet.getCurrentRowAsStruct()); + } + } + } + + private ResultSet execute(ReadOnlyTransaction readOnlyTransaction) { + if (config.getQuery() != null) { + return readOnlyTransaction.executeQuery(config.getQuery()); + } + if (config.getIndex() != null) { + return readOnlyTransaction.readUsingIndex( + config.getTable(), config.getIndex(), config.getKeySet(), config.getColumns()); + } + return readOnlyTransaction.read(config.getTable(), config.getKeySet(), config.getColumns()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java index 4cb8aa28bd637..02716fbaf4804 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.beam.sdk.io.gcp.spanner; import static com.google.common.base.Preconditions.checkNotNull; @@ -17,8 +34,6 @@ @AutoValue public abstract class SpannerConfig implements Serializable { - private static final long serialVersionUID = -5680874609304170301L; - @Nullable abstract ValueProvider getProjectId(); @@ -49,7 +64,7 @@ public static SpannerConfig create() { return builder().build(); } - public static Builder builder() { + static Builder builder() { return new AutoValue_SpannerConfig.Builder(); } @@ -79,14 +94,12 @@ public void populateDisplayData(DisplayData.Builder builder) { @AutoValue.Builder public abstract static class Builder { - abstract Builder setProjectId(ValueProvider projectId); abstract Builder setInstanceId(ValueProvider instanceId); abstract Builder setDatabaseId(ValueProvider databaseId); - abstract Builder setServiceFactory(ServiceFactory serviceFactory); public abstract SpannerConfig build(); @@ -115,4 +128,10 @@ public SpannerConfig withDatabaseId(ValueProvider databaseId) { public SpannerConfig withDatabaseId(String databaseId) { return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); } + + @VisibleForTesting + SpannerConfig withServiceFactory(ServiceFactory serviceFactory) { + return toBuilder().setServiceFactory(serviceFactory).build(); + } + } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 791c7e71daf73..acf928530a6bd 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -17,23 +17,38 @@ */ package org.apache.beam.sdk.io.gcp.spanner; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.auto.value.AutoValue; import com.google.cloud.ServiceFactory; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.KeySet; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TimestampBound; import com.google.common.annotations.VisibleForTesting; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; /** @@ -42,7 +57,69 @@ * *
<h3>Reading from Cloud Spanner</h3>
  *
- * <p>This functionality is not yet implemented.
+ * <p>To read from Cloud Spanner, apply {@link SpannerIO.Read} transformation. It will return a
+ * {@link PCollection} of {@link Struct Structs}, where each element represents
+ * an individual row returned from the read operation. Both Query and Read APIs are supported.
+ * See more information about reading from
+ * Cloud Spanner
+ *
+ * <p>To execute a query, specify a {@link SpannerIO.Read#withQuery(Statement)} or
+ * {@link SpannerIO.Read#withQuery(String)} during the construction of the transform.
+ *
+ * <pre>{@code
+ *  PCollection<Struct> rows = p.apply(
+ *      SpannerIO.read()
+ *          .withInstanceId(instanceId)
+ *          .withDatabaseId(dbId)
+ *          .withQuery("SELECT id, name, email FROM users"));
+ * }</pre>
+ *
+ * <p>To use the Read API, specify a {@link SpannerIO.Read#withTable(String) table name} and
+ * a {@link SpannerIO.Read#withColumns(List) list of columns}.
+ *
+ * <pre>{@code
+ * PCollection<Struct> rows = p.apply(
+ *    SpannerIO.read()
+ *        .withInstanceId(instanceId)
+ *        .withDatabaseId(dbId)
+ *        .withTable("users")
+ *        .withColumns("id", "name", "email"));
+ * }</pre>
+ *
+ * <p>To optimally read using index, specify the index name using {@link SpannerIO.Read#withIndex}.
+ *
+ * <p>The transform is guaranteed to be executed on a consistent snapshot of data, utilizing the
+ * power of read only transactions. Staleness of data can be controlled using
+ * {@link SpannerIO.Read#withTimestampBound} or {@link SpannerIO.Read#withTimestamp(Timestamp)}
+ * methods. Read more about
+ * transactions in Cloud Spanner.
+ *
+ * <p>It is possible to read several {@link PCollection PCollections} within a single transaction.
+ * Apply {@link SpannerIO#createTransaction()} transform, that lazily creates a transaction. The
+ * result of this transformation can be passed to read operation using
+ * {@link SpannerIO.Read#withTransaction(PCollectionView)}.
+ *
+ * <pre>{@code
+ * SpannerConfig spannerConfig = ...
+ *
+ * PCollectionView<Transaction> tx =
+ * p.apply(
+ *    SpannerIO.createTransaction()
+ *        .withSpannerConfig(spannerConfig)
+ *        .withTimestampBound(TimestampBound.strong()));
+ *
+ * PCollection<Struct> users = p.apply(
+ *    SpannerIO.read()
+ *        .withSpannerConfig(spannerConfig)
+ *        .withQuery("SELECT name, email FROM users")
+ *        .withTransaction(tx));
+ *
+ * PCollection<Struct> tweets = p.apply(
+ *    SpannerIO.read()
+ *        .withSpannerConfig(spannerConfig)
+ *        .withQuery("SELECT user, tweet, date FROM tweets")
+ *        .withTransaction(tx));
+ * }</pre>
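+ *
+ * <p>A minimal sketch of an index-based read, assuming a secondary index named
+ * "users_by_email" exists on the "users" table (the index and column names here are
+ * illustrative values, not part of the API):
+ *
+ * <pre>{@code
+ * PCollection<Struct> rows = p.apply(
+ *    SpannerIO.read()
+ *        .withInstanceId(instanceId)
+ *        .withDatabaseId(dbId)
+ *        .withTable("users")
+ *        .withIndex("users_by_email")
+ *        .withColumns("id", "name", "email"));
+ * }</pre>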
  *
  * <h3>Writing to Cloud Spanner</h3>
    * @@ -85,6 +162,33 @@ public class SpannerIO { private static final long DEFAULT_BATCH_SIZE_BYTES = 1024 * 1024; // 1 MB + /** + * Creates an uninitialized instance of {@link Read}. Before use, the {@link Read} must be + * configured with a {@link Read#withInstanceId} and {@link Read#withDatabaseId} that identify the + * Cloud Spanner database. + */ + @Experimental + public static Read read() { + return new AutoValue_SpannerIO_Read.Builder() + .setSpannerConfig(SpannerConfig.create()) + .setTimestampBound(TimestampBound.strong()) + .setKeySet(KeySet.all()) + .build(); + } + + /** + * Returns a transform that creates a batch transaction. By default, + * {@link TimestampBound#strong()} transaction is created, to override this use + * {@link CreateTransaction#withTimestampBound(TimestampBound)}. + */ + @Experimental + public static CreateTransaction createTransaction() { + return new AutoValue_SpannerIO_CreateTransaction.Builder() + .setSpannerConfig(SpannerConfig.create()) + .setTimestampBound(TimestampBound.strong()) + .build(); + } + /** * Creates an uninitialized instance of {@link Write}. Before use, the {@link Write} must be * configured with a {@link Write#withInstanceId} and {@link Write#withDatabaseId} that identify @@ -93,10 +197,285 @@ public class SpannerIO { @Experimental public static Write write() { return new AutoValue_SpannerIO_Write.Builder() + .setSpannerConfig(SpannerConfig.create()) .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES) .build(); } + /** + * A {@link PTransform} that reads data from Google Cloud Spanner. + * + * @see SpannerIO + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + @AutoValue + public abstract static class Read extends PTransform> { + + abstract SpannerConfig getSpannerConfig(); + + @Nullable + abstract TimestampBound getTimestampBound(); + + @Nullable + abstract Statement getQuery(); + + @Nullable + abstract String getTable(); + + @Nullable + abstract String getIndex(); + + @Nullable + abstract List getColumns(); + + @Nullable + abstract KeySet getKeySet(); + + @Nullable + abstract PCollectionView getTransaction(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setSpannerConfig(SpannerConfig spannerConfig); + + abstract Builder setTimestampBound(TimestampBound timestampBound); + + abstract Builder setQuery(Statement statement); + + abstract Builder setTable(String table); + + abstract Builder setIndex(String index); + + abstract Builder setColumns(List columns); + + abstract Builder setKeySet(KeySet keySet); + + abstract Builder setTransaction(PCollectionView transaction); + + abstract Read build(); + } + + /** Specifies the Cloud Spanner configuration. */ + public Read withSpannerConfig(SpannerConfig spannerConfig) { + return toBuilder().setSpannerConfig(spannerConfig).build(); + } + + /** Specifies the Cloud Spanner project. */ + public Read withProjectId(String projectId) { + return withProjectId(ValueProvider.StaticValueProvider.of(projectId)); + } + + /** Specifies the Cloud Spanner project. */ + public Read withProjectId(ValueProvider projectId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withProjectId(projectId)); + } + + /** Specifies the Cloud Spanner instance. */ + public Read withInstanceId(String instanceId) { + return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId)); + } + + /** Specifies the Cloud Spanner instance. 
*/ + public Read withInstanceId(ValueProvider instanceId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withInstanceId(instanceId)); + } + + /** Specifies the Cloud Spanner database. */ + public Read withDatabaseId(String databaseId) { + return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); + } + + /** Specifies the Cloud Spanner database. */ + public Read withDatabaseId(ValueProvider databaseId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withDatabaseId(databaseId)); + } + + @VisibleForTesting + Read withServiceFactory(ServiceFactory serviceFactory) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withServiceFactory(serviceFactory)); + } + + public Read withTransaction(PCollectionView transaction) { + return toBuilder().setTransaction(transaction).build(); + } + + public Read withTimestamp(Timestamp timestamp) { + return withTimestampBound(TimestampBound.ofReadTimestamp(timestamp)); + } + + public Read withTimestampBound(TimestampBound timestampBound) { + return toBuilder().setTimestampBound(timestampBound).build(); + } + + public Read withTable(String table) { + return toBuilder().setTable(table).build(); + } + + public Read withColumns(String... columns) { + return withColumns(Arrays.asList(columns)); + } + + public Read withColumns(List columns) { + return toBuilder().setColumns(columns).build(); + } + + public Read withQuery(Statement statement) { + return toBuilder().setQuery(statement).build(); + } + + public Read withQuery(String sql) { + return withQuery(Statement.of(sql)); + } + + public Read withKeySet(KeySet keySet) { + return toBuilder().setKeySet(keySet).build(); + } + + public Read withIndex(String index) { + return toBuilder().setIndex(index).build(); + } + + + @Override + public void validate(PipelineOptions options) { + getSpannerConfig().validate(options); + checkNotNull( + getTimestampBound(), + "SpannerIO.read() runs in a read only transaction and requires timestamp to be set " + + "with withTimestampBound or withTimestamp method"); + + if (getQuery() != null) { + // TODO: validate query? + } else if (getTable() != null) { + // Assume read + checkNotNull( + getColumns(), + "For a read operation SpannerIO.read() requires a list of " + + "columns to set with withColumns method"); + checkArgument( + !getColumns().isEmpty(), + "For a read operation SpannerIO.read() requires a" + + " list of columns to set with withColumns method"); + } else { + throw new IllegalArgumentException( + "SpannerIO.read() requires configuring query or read operation."); + } + } + + @Override + public PCollection expand(PBegin input) { + Read config = this; + List> sideInputs = Collections.emptyList(); + if (getTimestampBound() != null) { + PCollectionView transaction = + input.apply(createTransaction().withSpannerConfig(getSpannerConfig())); + config = config.withTransaction(transaction); + sideInputs = Collections.singletonList(transaction); + } + return input + .apply(Create.of(1)) + .apply( + "Execute query", ParDo.of(new NaiveSpannerReadFn(config)).withSideInputs(sideInputs)); + } + } + + /** + * A {@link PTransform} that create a transaction. 
+ * + * @see SpannerIO + */ + @Experimental(Experimental.Kind.SOURCE_SINK) + @AutoValue + public abstract static class CreateTransaction + extends PTransform> { + + abstract SpannerConfig getSpannerConfig(); + + @Nullable + abstract TimestampBound getTimestampBound(); + + abstract Builder toBuilder(); + + @Override + public PCollectionView expand(PBegin input) { + return input.apply(Create.of(1)) + .apply("Create transaction", ParDo.of(new CreateTransactionFn(this))) + .apply("As PCollectionView", View.asSingleton()); + } + + /** Specifies the Cloud Spanner configuration. */ + public CreateTransaction withSpannerConfig(SpannerConfig spannerConfig) { + return toBuilder().setSpannerConfig(spannerConfig).build(); + } + + /** Specifies the Cloud Spanner project. */ + public CreateTransaction withProjectId(String projectId) { + return withProjectId(ValueProvider.StaticValueProvider.of(projectId)); + } + + /** Specifies the Cloud Spanner project. */ + public CreateTransaction withProjectId(ValueProvider projectId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withProjectId(projectId)); + } + + /** Specifies the Cloud Spanner instance. */ + public CreateTransaction withInstanceId(String instanceId) { + return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId)); + } + + /** Specifies the Cloud Spanner instance. */ + public CreateTransaction withInstanceId(ValueProvider instanceId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withInstanceId(instanceId)); + } + + /** Specifies the Cloud Spanner database. */ + public CreateTransaction withDatabaseId(String databaseId) { + return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); + } + + /** Specifies the Cloud Spanner database. */ + public CreateTransaction withDatabaseId(ValueProvider databaseId) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withDatabaseId(databaseId)); + } + + @VisibleForTesting + CreateTransaction withServiceFactory( + ServiceFactory serviceFactory) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withServiceFactory(serviceFactory)); + } + + public CreateTransaction withTimestampBound(TimestampBound timestampBound) { + return toBuilder().setTimestampBound(timestampBound).build(); + } + + @Override + public void validate(PipelineOptions options) { + getSpannerConfig().validate(options); + } + + /** A builder for {@link CreateTransaction}. */ + @AutoValue.Builder public abstract static class Builder { + + public abstract Builder setSpannerConfig(SpannerConfig spannerConfig); + + public abstract Builder setTimestampBound(TimestampBound newTimestampBound); + + public abstract CreateTransaction build(); + } + } + + /** * A {@link PTransform} that writes {@link Mutation} objects to Google Cloud Spanner. * @@ -106,8 +485,6 @@ public static Write write() { @AutoValue public abstract static class Write extends PTransform, PDone> { - private static final long serialVersionUID = 1920175411827980145L; - abstract SpannerConfig getSpannerConfig(); abstract long getBatchSizeBytes(); @@ -119,95 +496,53 @@ abstract static class Builder { abstract Builder setSpannerConfig(SpannerConfig spannerConfig); - abstract SpannerConfig.Builder spannerConfigBuilder(); - abstract Builder setBatchSizeBytes(long batchSizeBytes); abstract Write build(); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner configuration. */ + public Write withSpannerConfig(SpannerConfig spannerConfig) { + return toBuilder().setSpannerConfig(spannerConfig).build(); + } + + /** Specifies the Cloud Spanner project. */ public Write withProjectId(String projectId) { return withProjectId(ValueProvider.StaticValueProvider.of(projectId)); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner project. */ public Write withProjectId(ValueProvider projectId) { - Write.Builder builder = toBuilder(); - builder.spannerConfigBuilder().setProjectId(projectId); - return builder.build(); + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withProjectId(projectId)); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner - * instance. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner instance. */ public Write withInstanceId(String instanceId) { return withInstanceId(ValueProvider.StaticValueProvider.of(instanceId)); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner - * instance. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner instance. */ public Write withInstanceId(ValueProvider instanceId) { - Write.Builder builder = toBuilder(); - builder.spannerConfigBuilder().setInstanceId(instanceId); - return builder.build(); + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withInstanceId(instanceId)); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner - * config. - * - *
    Does not modify this object. - */ - public Write withSpannerConfig(SpannerConfig spannerConfig) { - return toBuilder().setSpannerConfig(spannerConfig).build(); - } - - - /** - * Returns a new {@link SpannerIO.Write} with a new batch size limit. - * - *
    Does not modify this object. - */ - public Write withBatchSizeBytes(long batchSizeBytes) { - return toBuilder().setBatchSizeBytes(batchSizeBytes).build(); - } - - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner - * database. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner database. */ public Write withDatabaseId(String databaseId) { return withDatabaseId(ValueProvider.StaticValueProvider.of(databaseId)); } - /** - * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner - * database. - * - *
    Does not modify this object. - */ + /** Specifies the Cloud Spanner database. */ public Write withDatabaseId(ValueProvider databaseId) { - Write.Builder builder = toBuilder(); - builder.spannerConfigBuilder().setDatabaseId(databaseId); - return builder.build(); + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withDatabaseId(databaseId)); + } + + @VisibleForTesting + Write withServiceFactory(ServiceFactory serviceFactory) { + SpannerConfig config = getSpannerConfig(); + return withSpannerConfig(config.withServiceFactory(serviceFactory)); } /** @@ -217,11 +552,9 @@ public WriteGrouped grouped() { return new WriteGrouped(this); } - @VisibleForTesting - Write withServiceFactory(ServiceFactory serviceFactory) { - Write.Builder builder = toBuilder(); - builder.spannerConfigBuilder().setServiceFactory(serviceFactory); - return builder.build(); + /** Specifies the batch size limit. */ + public Write withBatchSizeBytes(long batchSizeBytes) { + return toBuilder().setBatchSizeBytes(batchSizeBytes).build(); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java index aed4832b7d863..34a11da8754f5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteGroupFn.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.sdk.io.gcp.spanner; import com.google.cloud.spanner.AbortedException; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java new file mode 100644 index 0000000000000..22af3b8fe6639 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/Transaction.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.auto.value.AutoValue; +import com.google.cloud.Timestamp; +import java.io.Serializable; + +/** A transaction object. */ +@AutoValue +public abstract class Transaction implements Serializable { + + abstract Timestamp timestamp(); + + public static Transaction create(Timestamp timestamp) { + return new AutoValue_Transaction(timestamp); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java index 91caded1ad35f..8aac417f58167 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/GcpApiSurfaceTest.java @@ -52,6 +52,7 @@ public void testGcpApiSurface() throws Exception { @SuppressWarnings("unchecked") final Set>> allowedClasses = ImmutableSet.of( + classesInPackage("com.google.api.core"), classesInPackage("com.google.api.client.googleapis"), classesInPackage("com.google.api.client.http"), classesInPackage("com.google.api.client.json"), @@ -60,9 +61,18 @@ public void testGcpApiSurface() throws Exception { classesInPackage("com.google.auth"), classesInPackage("com.google.bigtable.v2"), classesInPackage("com.google.cloud.bigtable.config"), + classesInPackage("com.google.spanner.v1"), + Matchers.>equalTo(com.google.api.gax.grpc.ApiException.class), Matchers.>equalTo(com.google.cloud.bigtable.grpc.BigtableClusterName.class), Matchers.>equalTo(com.google.cloud.bigtable.grpc.BigtableInstanceName.class), Matchers.>equalTo(com.google.cloud.bigtable.grpc.BigtableTableName.class), + Matchers.>equalTo(com.google.cloud.BaseServiceException.class), + Matchers.>equalTo(com.google.cloud.BaseServiceException.Error.class), + Matchers.>equalTo(com.google.cloud.BaseServiceException.ExceptionData.class), + Matchers.>equalTo(com.google.cloud.BaseServiceException.ExceptionData.Builder + .class), + Matchers.>equalTo(com.google.cloud.RetryHelper.RetryHelperException.class), + Matchers.>equalTo(com.google.cloud.grpc.BaseGrpcServiceException.class), Matchers.>equalTo(com.google.cloud.ByteArray.class), Matchers.>equalTo(com.google.cloud.Date.class), Matchers.>equalTo(com.google.cloud.Timestamp.class), diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java new file mode 100644 index 0000000000000..753d807eb7ee8 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/FakeServiceFactory.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.cloud.ServiceFactory; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.concurrent.GuardedBy; +import org.mockito.Matchers; + +/** + * A serialization friendly type service factory that maintains a mock {@link Spanner} and + * {@link DatabaseClient}. + * */ +class FakeServiceFactory + implements ServiceFactory, Serializable { + + // Marked as static so they could be returned by serviceFactory, which is serializable. + private static final Object lock = new Object(); + + @GuardedBy("lock") + private static final List mockSpanners = new ArrayList<>(); + + @GuardedBy("lock") + private static final List mockDatabaseClients = new ArrayList<>(); + + @GuardedBy("lock") + private static int count = 0; + + private final int index; + + public FakeServiceFactory() { + synchronized (lock) { + index = count++; + mockSpanners.add(mock(Spanner.class, withSettings().serializable())); + mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable())); + } + when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class))) + .thenReturn(mockDatabaseClient()); + } + + DatabaseClient mockDatabaseClient() { + synchronized (lock) { + return mockDatabaseClients.get(index); + } + } + + Spanner mockSpanner() { + synchronized (lock) { + return mockSpanners.get(index); + } + } + + @Override + public Spanner create(SpannerOptions serviceOptions) { + return mockSpanner(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java new file mode 100644 index 0000000000000..e5d4e72f51981 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.KeySet; +import com.google.cloud.spanner.ReadOnlyTransaction; +import com.google.cloud.spanner.ResultSets; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TimestampBound; +import com.google.cloud.spanner.Type; +import com.google.cloud.spanner.Value; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFnTester; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** Unit tests for {@link SpannerIO}. 
*/ +@RunWith(JUnit4.class) +public class SpannerIOReadTest implements Serializable { + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule + public final transient ExpectedException thrown = ExpectedException.none(); + + private FakeServiceFactory serviceFactory; + private ReadOnlyTransaction mockTx; + + private Type fakeType = Type.struct(Type.StructField.of("id", Type.int64()), + Type.StructField.of("name", Type.string())); + + private List fakeRows = Arrays.asList( + Struct.newBuilder().add("id", Value.int64(1)).add("name", Value.string("Alice")).build(), + Struct.newBuilder().add("id", Value.int64(2)).add("name", Value.string("Bob")).build()); + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + serviceFactory = new FakeServiceFactory(); + mockTx = Mockito.mock(ReadOnlyTransaction.class); + } + + @Test + public void emptyTransform() throws Exception { + SpannerIO.Read read = SpannerIO.read(); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires instance id to be set with"); + read.validate(null); + } + + @Test + public void emptyInstanceId() throws Exception { + SpannerIO.Read read = SpannerIO.read().withDatabaseId("123"); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires instance id to be set with"); + read.validate(null); + } + + @Test + public void emptyDatabaseId() throws Exception { + SpannerIO.Read read = SpannerIO.read().withInstanceId("123"); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires database id to be set with"); + read.validate(null); + } + + @Test + public void emptyQuery() throws Exception { + SpannerIO.Read read = + SpannerIO.read().withInstanceId("123").withDatabaseId("aaa").withTimestamp(Timestamp.now()); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("requires configuring query or read operation"); + read.validate(null); + } + + @Test + public void emptyColumns() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withTable("users"); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires a list of columns"); + read.validate(null); + } + + @Test + public void validRead() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withTable("users") + .withColumns("id", "name", "email"); + read.validate(null); + } + + @Test + public void validQuery() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withQuery("SELECT * FROM users"); + read.validate(null); + } + + @Test + public void runQuery() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withQuery("SELECT * FROM users") + .withServiceFactory(serviceFactory); + + NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read); + DoFnTester fnTester = DoFnTester.of(readFn); + + when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class))) + .thenReturn(mockTx); + when(mockTx.executeQuery(any(Statement.class))) + .thenReturn(ResultSets.forRows(fakeType, fakeRows)); + + List result = fnTester.processBundle(1); + assertThat(result, Matchers.iterableWithSize(2)); + + 
verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound + .strong()); + verify(mockTx).executeQuery(Statement.of("SELECT * FROM users")); + } + + @Test + public void runRead() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withTable("users") + .withColumns("id", "name") + .withServiceFactory(serviceFactory); + + NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read); + DoFnTester fnTester = DoFnTester.of(readFn); + + when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class))) + .thenReturn(mockTx); + when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name"))) + .thenReturn(ResultSets.forRows(fakeType, fakeRows)); + + List result = fnTester.processBundle(1); + assertThat(result, Matchers.iterableWithSize(2)); + + verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound.strong()); + verify(mockTx).read("users", KeySet.all(), Arrays.asList("id", "name")); + } + + @Test + public void runReadUsingIndex() throws Exception { + SpannerIO.Read read = + SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withTable("users") + .withColumns("id", "name") + .withIndex("theindex") + .withServiceFactory(serviceFactory); + + NaiveSpannerReadFn readFn = new NaiveSpannerReadFn(read); + DoFnTester fnTester = DoFnTester.of(readFn); + + when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class))) + .thenReturn(mockTx); + when(mockTx.readUsingIndex("users", "theindex", KeySet.all(), Arrays.asList("id", "name"))) + .thenReturn(ResultSets.forRows(fakeType, fakeRows)); + + List result = fnTester.processBundle(1); + assertThat(result, Matchers.iterableWithSize(2)); + + verify(serviceFactory.mockDatabaseClient()).readOnlyTransaction(TimestampBound.strong()); + verify(mockTx).readUsingIndex("users", "theindex", KeySet.all(), Arrays.asList("id", "name")); + } + + @Test + @Category(NeedsRunner.class) + public void readPipeline() throws Exception { + Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345); + + PCollectionView tx = pipeline + .apply("tx", SpannerIO.createTransaction() + .withInstanceId("123") + .withDatabaseId("aaa") + .withServiceFactory(serviceFactory)); + + PCollection one = pipeline.apply("read q", SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withQuery("SELECT * FROM users") + .withServiceFactory(serviceFactory) + .withTransaction(tx)); + PCollection two = pipeline.apply("read r", SpannerIO.read() + .withInstanceId("123") + .withDatabaseId("aaa") + .withTimestamp(Timestamp.now()) + .withTable("users") + .withColumns("id", "name") + .withServiceFactory(serviceFactory) + .withTransaction(tx)); + + when(serviceFactory.mockDatabaseClient().readOnlyTransaction(any(TimestampBound.class))) + .thenReturn(mockTx); + + when(mockTx.executeQuery(Statement.of("SELECT 1"))).thenReturn(ResultSets.forRows(Type.struct(), + Collections.emptyList())); + + when(mockTx.executeQuery(Statement.of("SELECT * FROM users"))) + .thenReturn(ResultSets.forRows(fakeType, fakeRows)); + when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name"))) + .thenReturn(ResultSets.forRows(fakeType, fakeRows)); + when(mockTx.getReadTimestamp()).thenReturn(timestamp); + + PAssert.that(one).containsInAnyOrder(fakeRows); + PAssert.that(two).containsInAnyOrder(fakeRows); + + pipeline.run(); + + 
verify(serviceFactory.mockDatabaseClient(), times(2)) + .readOnlyTransaction(TimestampBound.ofReadTimestamp(timestamp)); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java similarity index 85% rename from sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java rename to sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java index abeac0a8f4ae1..09cdb8e995d7a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java @@ -21,24 +21,14 @@ import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.argThat; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; -import com.google.cloud.ServiceFactory; -import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; import com.google.cloud.spanner.Mutation; -import com.google.cloud.spanner.Spanner; -import com.google.cloud.spanner.SpannerOptions; import com.google.common.collect.Iterables; import java.io.Serializable; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; -import javax.annotation.concurrent.GuardedBy; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; @@ -54,14 +44,12 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatcher; -import org.mockito.Matchers; - /** * Unit tests for {@link SpannerIO}. */ @RunWith(JUnit4.class) -public class SpannerIOTest implements Serializable { +public class SpannerIOWriteTest implements Serializable { @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @Rule public transient ExpectedException thrown = ExpectedException.none(); @@ -251,50 +239,6 @@ public void displayData() throws Exception { assertThat(data, hasDisplayItem("batchSizeBytes", 123)); } - private static class FakeServiceFactory - implements ServiceFactory, Serializable { - // Marked as static so they could be returned by serviceFactory, which is serializable. 
- private static final Object lock = new Object(); - - @GuardedBy("lock") - private static final List mockSpanners = new ArrayList<>(); - - @GuardedBy("lock") - private static final List mockDatabaseClients = new ArrayList<>(); - - @GuardedBy("lock") - private static int count = 0; - - private final int index; - - public FakeServiceFactory() { - synchronized (lock) { - index = count++; - mockSpanners.add(mock(Spanner.class, withSettings().serializable())); - mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable())); - } - when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class))) - .thenReturn(mockDatabaseClient()); - } - - DatabaseClient mockDatabaseClient() { - synchronized (lock) { - return mockDatabaseClients.get(index); - } - } - - Spanner mockSpanner() { - synchronized (lock) { - return mockSpanners.get(index); - } - } - - @Override - public Spanner create(SpannerOptions serviceOptions) { - return mockSpanner(); - } - } - private static class IterableOfSize extends ArgumentMatcher> { private final int size; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java new file mode 100644 index 0000000000000..f5d7cbd6c31d3 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseAdminClient; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Operation; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.TimestampBound; +import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** End-to-end test of Cloud Spanner Source. */ +@RunWith(JUnit4.class) +public class SpannerReadIT { + + private static final int MAX_DB_NAME_LENGTH = 30; + + @Rule public final transient TestPipeline p = TestPipeline.create(); + + /** Pipeline options for this test. */ + public interface SpannerTestPipelineOptions extends TestPipelineOptions { + @Description("Project ID for Spanner") + @Default.String("apache-beam-testing") + String getProjectId(); + void setProjectId(String value); + + @Description("Instance ID to write to in Spanner") + @Default.String("beam-test") + String getInstanceId(); + void setInstanceId(String value); + + @Description("Database ID prefix to write to in Spanner") + @Default.String("beam-testdb") + String getDatabaseIdPrefix(); + void setDatabaseIdPrefix(String value); + + @Description("Table name") + @Default.String("users") + String getTable(); + void setTable(String value); + } + + private Spanner spanner; + private DatabaseAdminClient databaseAdminClient; + private SpannerTestPipelineOptions options; + private String databaseName; + + @Before + public void setUp() throws Exception { + PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); + options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); + + spanner = SpannerOptions.newBuilder().setProjectId(options.getProjectId()).build().getService(); + + databaseName = generateDatabaseName(); + + databaseAdminClient = spanner.getDatabaseAdminClient(); + + // Delete database if exists. 
+ databaseAdminClient.dropDatabase(options.getInstanceId(), databaseName); + + Operation op = + databaseAdminClient.createDatabase( + options.getInstanceId(), + databaseName, + Collections.singleton( + "CREATE TABLE " + + options.getTable() + + " (" + + " Key INT64," + + " Value STRING(MAX)," + + ") PRIMARY KEY (Key)")); + op.waitFor(); + } + + @Test + public void testRead() throws Exception { + DatabaseClient databaseClient = + spanner.getDatabaseClient( + DatabaseId.of( + options.getProjectId(), options.getInstanceId(), databaseName)); + + List mutations = new ArrayList<>(); + for (int i = 0; i < 5L; i++) { + mutations.add( + Mutation.newInsertOrUpdateBuilder(options.getTable()) + .set("key") + .to((long) i) + .set("value") + .to(RandomStringUtils.random(100, true, true)) + .build()); + } + + databaseClient.writeAtLeastOnce(mutations); + + SpannerConfig spannerConfig = SpannerConfig.create() + .withProjectId(options.getProjectId()) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(databaseName); + + PCollectionView tx = + p.apply( + SpannerIO.createTransaction() + .withSpannerConfig(spannerConfig) + .withTimestampBound(TimestampBound.strong())); + + PCollection output = + p.apply( + SpannerIO.read() + .withSpannerConfig(spannerConfig) + .withQuery("SELECT * FROM " + options.getTable()) + .withTransaction(tx)); + PAssert.thatSingleton(output.apply("Count rows", Count.globally())).isEqualTo(5L); + p.run(); + } + + @After + public void tearDown() throws Exception { + databaseAdminClient.dropDatabase(options.getInstanceId(), databaseName); + spanner.close(); + } + + private String generateDatabaseName() { + String random = + RandomStringUtils.randomAlphanumeric( + MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()) + .toLowerCase(); + return options.getDatabaseIdPrefix() + "-" + random; + } +} From 58fba590ddc554a343036a7beeffe9caa319aa81 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 27 Jun 2017 14:35:00 -0700 Subject: [PATCH 126/200] Add utility to expand list of PCollectionViews --- .../apache/beam/sdk/values/PCollectionViews.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java index 0c04370a11c51..e17e146353e45 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java @@ -21,6 +21,7 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import java.io.IOException; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.transforms.Materialization; import org.apache.beam.sdk.transforms.Materializations; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ViewFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.InvalidWindows; @@ -138,6 +140,18 @@ public static PCollectionView valueCoder); } + /** + * Expands a list of {@link PCollectionView} into the form needed for + * {@link PTransform#getAdditionalInputs()}. 
+ */ + public static Map, PValue> toAdditionalInputs(Iterable> views) { + ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); + for (PCollectionView view : views) { + additionalInputs.put(view.getTagInternal(), view.getPCollection()); + } + return additionalInputs.build(); + } + /** * Implementation of conversion of singleton {@code Iterable>} to {@code T}. * From a66bcd68a1e56d5d38fccfce2ffeec28ba1c82de Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 13 Jun 2017 10:00:09 -0700 Subject: [PATCH 127/200] Fix getAdditionalInputs for SplittableParDo transforms --- .../apache/beam/runners/apex/ApexRunner.java | 2 +- .../core/construction/SplittableParDo.java | 66 ++++++++++++++----- .../construction/SplittableParDoTest.java | 8 +-- .../direct/ParDoMultiOverrideFactory.java | 2 +- .../FlinkStreamingPipelineTranslator.java | 2 +- .../dataflow/SplittableParDoOverrides.java | 2 +- 6 files changed, 57 insertions(+), 25 deletions(-) diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java index 95b354a9fe337..fd0a1c93d1e66 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java @@ -381,7 +381,7 @@ public PTransformReplacement, PCollectionTuple> getReplaceme AppliedPTransform, PCollectionTuple, MultiOutput> transform) { return PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(transform), - new SplittableParDo<>(transform.getTransform())); + SplittableParDo.forJavaParDo(transform.getTransform())); } @Override diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 5ccafcbc8ea17..f31b495739b8a 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -18,9 +18,9 @@ package org.apache.beam.runners.core.construction; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; import java.util.List; +import java.util.Map; import java.util.UUID; import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; import org.apache.beam.sdk.annotations.Experimental; @@ -40,6 +40,8 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; @@ -64,7 +66,11 @@ @Experimental(Experimental.Kind.SPLITTABLE_DO_FN) public class SplittableParDo extends PTransform, PCollectionTuple> { - private final ParDo.MultiOutput parDo; + + private final DoFn doFn; + private final List> sideInputs; + private final TupleTag mainOutputTag; + private final TupleTagList additionalOutputTags; public static final String SPLITTABLE_PROCESS_URN = "urn:beam:runners_core:transforms:splittable_process:v1"; @@ -75,24 +81,39 @@ public class SplittableParDo public static final String SPLITTABLE_GBKIKWI_URN = 
"urn:beam:runners_core:transforms:splittable_gbkikwi:v1"; + private SplittableParDo( + DoFn doFn, + TupleTag mainOutputTag, + List> sideInputs, + TupleTagList additionalOutputTags) { + checkArgument( + DoFnSignatures.getSignature(doFn.getClass()).processElement().isSplittable(), + "fn must be a splittable DoFn"); + this.doFn = doFn; + this.mainOutputTag = mainOutputTag; + this.sideInputs = sideInputs; + this.additionalOutputTags = additionalOutputTags; + } + /** - * Creates the transform for the given original multi-output {@link ParDo}. + * Creates a {@link SplittableParDo} from an original Java {@link ParDo}. * * @param parDo The splittable {@link ParDo} transform. */ - public SplittableParDo(ParDo.MultiOutput parDo) { - checkNotNull(parDo, "parDo must not be null"); - this.parDo = parDo; - checkArgument( - DoFnSignatures.getSignature(parDo.getFn().getClass()).processElement().isSplittable(), - "fn must be a splittable DoFn"); + public static SplittableParDo forJavaParDo( + ParDo.MultiOutput parDo) { + checkArgument(parDo != null, "parDo must not be null"); + return new SplittableParDo( + parDo.getFn(), + parDo.getMainOutputTag(), + parDo.getSideInputs(), + parDo.getAdditionalOutputTags()); } @Override public PCollectionTuple expand(PCollection input) { - DoFn fn = parDo.getFn(); Coder restrictionCoder = - DoFnInvokers.invokerFor(fn) + DoFnInvokers.invokerFor(doFn) .invokeGetRestrictionCoder(input.getPipeline().getCoderRegistry()); Coder> splitCoder = KvCoder.of(input.getCoder(), restrictionCoder); @@ -100,9 +121,10 @@ public PCollectionTuple expand(PCollection input) { input .apply( "Pair with initial restriction", - ParDo.of(new PairWithRestrictionFn(fn))) + ParDo.of(new PairWithRestrictionFn(doFn))) .setCoder(splitCoder) - .apply("Split restriction", ParDo.of(new SplitRestrictionFn(fn))) + .apply( + "Split restriction", ParDo.of(new SplitRestrictionFn(doFn))) .setCoder(splitCoder) // ProcessFn requires all input elements to be in a single window and have a single // element per work item. 
This must precede the unique keying so each key has a single @@ -115,13 +137,18 @@ public PCollectionTuple expand(PCollection input) { return keyedRestrictions.apply( "ProcessKeyedElements", new ProcessKeyedElements<>( - fn, + doFn, input.getCoder(), restrictionCoder, (WindowingStrategy) input.getWindowingStrategy(), - parDo.getSideInputs(), - parDo.getMainOutputTag(), - parDo.getAdditionalOutputTags())); + sideInputs, + mainOutputTag, + additionalOutputTags)); + } + + @Override + public Map, PValue> getAdditionalInputs() { + return PCollectionViews.toAdditionalInputs(sideInputs); } /** @@ -230,6 +257,11 @@ public static PCollectionTuple createPrimitiveOutputFor( return outputs; } + @Override + public Map, PValue> getAdditionalInputs() { + return PCollectionViews.toAdditionalInputs(sideInputs); + } + @Override public String getUrn() { return SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN; diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java index 6e4d6c458f04c..f4c596e019517 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java @@ -122,14 +122,14 @@ public void testBoundednessForBoundedFn() { "Applying a bounded SDF to a bounded collection produces a bounded collection", PCollection.IsBounded.BOUNDED, makeBoundedCollection(pipeline) - .apply("bounded to bounded", new SplittableParDo<>(makeParDo(boundedFn))) + .apply("bounded to bounded", SplittableParDo.forJavaParDo(makeParDo(boundedFn))) .get(MAIN_OUTPUT_TAG) .isBounded()); assertEquals( "Applying a bounded SDF to an unbounded collection produces an unbounded collection", PCollection.IsBounded.UNBOUNDED, makeUnboundedCollection(pipeline) - .apply("bounded to unbounded", new SplittableParDo<>(makeParDo(boundedFn))) + .apply("bounded to unbounded", SplittableParDo.forJavaParDo(makeParDo(boundedFn))) .get(MAIN_OUTPUT_TAG) .isBounded()); } @@ -143,14 +143,14 @@ public void testBoundednessForUnboundedFn() { "Applying an unbounded SDF to a bounded collection produces a bounded collection", PCollection.IsBounded.UNBOUNDED, makeBoundedCollection(pipeline) - .apply("unbounded to bounded", new SplittableParDo<>(makeParDo(unboundedFn))) + .apply("unbounded to bounded", SplittableParDo.forJavaParDo(makeParDo(unboundedFn))) .get(MAIN_OUTPUT_TAG) .isBounded()); assertEquals( "Applying an unbounded SDF to an unbounded collection produces an unbounded collection", PCollection.IsBounded.UNBOUNDED, makeUnboundedCollection(pipeline) - .apply("unbounded to unbounded", new SplittableParDo<>(makeParDo(unboundedFn))) + .apply("unbounded to unbounded", SplittableParDo.forJavaParDo(makeParDo(unboundedFn))) .get(MAIN_OUTPUT_TAG) .isBounded()); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index b20113edf588d..9a26283214784 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -81,7 +81,7 @@ private PTransform, PCollectionTuple> getReplaceme DoFn fn = transform.getFn(); DoFnSignature 
signature = DoFnSignatures.getSignature(fn.getClass()); if (signature.processElement().isSplittable()) { - return new SplittableParDo(transform); + return (PTransform) SplittableParDo.forJavaParDo(transform); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java index 27bb4ecfb9fe4..ebc934516181f 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java @@ -188,7 +188,7 @@ static class SplittableParDoOverrideFactory transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new SplittableParDo<>(transform.getTransform())); + SplittableParDo.forJavaParDo(transform.getTransform())); } @Override diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java index 93228782372bc..fc010f81aadd1 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java @@ -64,7 +64,7 @@ public PTransformReplacement, PCollectionTuple> getReplaceme appliedTransform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(appliedTransform), - new SplittableParDo<>(appliedTransform.getTransform())); + SplittableParDo.forJavaParDo(appliedTransform.getTransform())); } @Override From 423827665ae5923cd7fccc654bd9a5e1efed7876 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 27 Jun 2017 14:39:06 -0700 Subject: [PATCH 128/200] Use PCollectionViews.toAdditionalInputs in ParDoMultiOverrideFactory --- .../runners/direct/ParDoMultiOverrideFactory.java | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 9a26283214784..2904bc170c442 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkState; -import com.google.common.collect.ImmutableMap; import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; @@ -50,6 +49,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -123,11 +123,7 @@ public GbkThenStatefulParDo( @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - 
additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } @Override @@ -231,11 +227,7 @@ public TupleTagList getAdditionalOutputTags() { @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } @Override From ed476dd2807577c8069087aa0764b21d1bb06512 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 27 Jun 2017 14:41:30 -0700 Subject: [PATCH 129/200] Use PCollectionViews.toAdditionalInputs in ParDo --- .../java/org/apache/beam/sdk/transforms/ParDo.java | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index edf14191cd1c2..db1f7918e4fcd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -20,7 +20,6 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import java.io.Serializable; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; @@ -50,6 +49,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -662,11 +662,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } } @@ -807,11 +803,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } } From 27674f07cf8363bb6b3c051a990caa5d61b8cd5c Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 27 Jun 2017 14:44:50 -0700 Subject: [PATCH 130/200] Use PCollectionViews.toAdditionalInputs in Combine --- .../org/apache/beam/sdk/transforms/Combine.java | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index 6a90bcfde2e48..d7effb5f7ed9d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -20,7 +20,6 @@ import static com.google.common.base.Preconditions.checkState; import 
com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import java.io.IOException; import java.io.InputStream; @@ -1122,11 +1121,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } /** @@ -1578,11 +1573,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - ImmutableMap.Builder, PValue> additionalInputs = ImmutableMap.builder(); - for (PCollectionView sideInput : sideInputs) { - additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); - } - return additionalInputs.build(); + return PCollectionViews.toAdditionalInputs(sideInputs); } @Override From b1ed9757cead18b006d2e22c73fe1399a3022ae5 Mon Sep 17 00:00:00 2001 From: Etienne Chauchot Date: Wed, 21 Jun 2017 10:14:08 +0200 Subject: [PATCH 131/200] [BEAM-2488] Elasticsearch IO should read also in replica shards --- sdks/java/io/elasticsearch/pom.xml | 8 ++++++++ .../beam/sdk/io/elasticsearch/ElasticsearchIO.java | 11 +---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sdks/java/io/elasticsearch/pom.xml b/sdks/java/io/elasticsearch/pom.xml index 03632cea0007c..c8e308c3ceac8 100644 --- a/sdks/java/io/elasticsearch/pom.xml +++ b/sdks/java/io/elasticsearch/pom.xml @@ -137,6 +137,14 @@ test + + + net.java.dev.jna + jna + 4.1.0 + test + + org.apache.beam beam-runners-direct-java diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index e3965dc6a0c01..fa67fe194f78f 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -455,16 +455,7 @@ public List> split( while (shards.hasNext()) { Map.Entry shardJson = shards.next(); String shardId = shardJson.getKey(); - JsonNode value = (JsonNode) shardJson.getValue(); - boolean isPrimaryShard = - value - .path(0) - .path("routing") - .path("primary") - .asBoolean(); - if (isPrimaryShard) { - sources.add(new BoundedElasticsearchSource(spec, shardId)); - } + sources.add(new BoundedElasticsearchSource(spec, shardId)); } checkArgument(!sources.isEmpty(), "No primary shard found"); return sources; From 2cb2161cec824a5ca5e719a92243029e712347c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Wed, 28 Jun 2017 10:31:25 +0200 Subject: [PATCH 132/200] Add Experimental annotation to AMQP and refine Kind for the Experimental IOs --- .../src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java | 2 ++ .../java/org/apache/beam/sdk/io/cassandra/CassandraIO.java | 2 +- .../apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java | 2 +- .../org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java | 6 +++--- .../java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java | 2 +- .../beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java | 2 +- .../src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java | 2 +- .../java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java | 2 +- 
.../src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java | 2 +- .../jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java | 2 +- .../src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 2 +- .../main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java | 2 +- .../org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java | 2 +- .../main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java | 2 +- .../src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java | 2 +- 15 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java index b9a0be9a078fe..1f307b252cb05 100644 --- a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java +++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java @@ -31,6 +31,7 @@ import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; @@ -94,6 +95,7 @@ * * } */ +@Experimental(Experimental.Kind.SOURCE_SINK) public class AmqpIO { public static Read read() { diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java index b6f4ef6bb67e0..32905b77258a8 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java @@ -82,7 +82,7 @@ * .withEntity(Person.class)); * } */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class CassandraIO { private static final Logger LOG = LoggerFactory.getLogger(CassandraIO.class); diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index fa67fe194f78f..4d7688772a00c 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -113,7 +113,7 @@ *

    Optionally, you can provide {@code withBatchSize()} and {@code withBatchSizeBytes()} * to specify the size of the write batch in number of documents or in bytes. */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class ElasticsearchIO { public static Read read() { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index 62679bb507d7f..0a90dde94f43c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -175,7 +175,7 @@ * pipeline. Please refer to the documentation of corresponding * {@link PipelineRunner PipelineRunners} for more details. */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class BigtableIO { private static final Logger LOG = LoggerFactory.getLogger(BigtableIO.class); @@ -211,7 +211,7 @@ public static Write write() { * * @see BigtableIO */ - @Experimental + @Experimental(Experimental.Kind.SOURCE_SINK) @AutoValue public abstract static class Read extends PTransform> { @@ -415,7 +415,7 @@ BigtableService getBigtableService(PipelineOptions pipelineOptions) { * * @see BigtableIO */ - @Experimental + @Experimental(Experimental.Kind.SOURCE_SINK) @AutoValue public abstract static class Write extends PTransform>>, PDone> { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index acf928530a6bd..a247d4cb10fc0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -167,7 +167,7 @@ public class SpannerIO { * configured with a {@link Read#withInstanceId} and {@link Read#withDatabaseId} that identify the * Cloud Spanner database. 
*/ - @Experimental + @Experimental(Experimental.Kind.SOURCE_SINK) public static Read read() { return new AutoValue_SpannerIO_Read.Builder() .setSpannerConfig(SpannerConfig.create()) diff --git a/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java b/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java index efd47fd85c4c8..0b4c23f65c831 100644 --- a/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java +++ b/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java @@ -166,7 +166,7 @@ * } * */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class HadoopInputFormatIO { private static final Logger LOG = LoggerFactory.getLogger(HadoopInputFormatIO.class); diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java index c9afe8908a5c0..90ede4ce01347 100644 --- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java +++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java @@ -140,7 +140,7 @@ * it can evolve or be different in some aspects, but the idea is that users can easily migrate * from one to the other

    . */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class HBaseIO { private static final Logger LOG = LoggerFactory.getLogger(HBaseIO.class); diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java index 1549dab048422..4199b805c0e78 100644 --- a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java +++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java @@ -106,7 +106,7 @@ * .withBatchSize(1024L)) //optional, assumes a default batch size of 1024 if none specified * } */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class HCatalogIO { private static final Logger LOG = LoggerFactory.getLogger(HCatalogIO.class); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 8092da6b971af..bf73dbef63f47 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -133,7 +133,7 @@ * Consider using MERGE ("upsert") * statements supported by your database instead. */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class JdbcIO { /** * Read data from a JDBC datasource. diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java index c5e51508c3ac6..f8cba5e0d8e17 100644 --- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java @@ -98,7 +98,7 @@ * * } */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class JmsIO { private static final Logger LOG = LoggerFactory.getLogger(JmsIO.class); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 4d2a3584f2311..702bdd32b712b 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -235,7 +235,7 @@ * Note that {@link KafkaRecord#getTimestamp()} reflects timestamp provided by Kafka if any, * otherwise it is set to processing time. */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class KafkaIO { /** diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java index c97316d9f4058..b85eb6347dbce 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java @@ -100,7 +100,7 @@ * } * */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public final class KinesisIO { /** Returns a new {@link Read} transform for reading from Kinesis. 
*/ public static Read read() { diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java index b63775da7827a..5b5412c9bee97 100644 --- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java +++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java @@ -117,7 +117,7 @@ * to the file separated with line feeds. *

    */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class MongoDbGridFSIO { /** diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java index 04d9975a6760a..3b14182f27608 100644 --- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java +++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java @@ -94,7 +94,7 @@ * * } */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class MongoDbIO { private static final Logger LOG = LoggerFactory.getLogger(MongoDbIO.class); diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index 228a85d77ab20..add5cb57f6103 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -97,7 +97,7 @@ * * } */ -@Experimental +@Experimental(Experimental.Kind.SOURCE_SINK) public class MqttIO { private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class); From fecd64f5ff73e590fcf19019534b9d0ed293ac60 Mon Sep 17 00:00:00 2001 From: Michael Luckey Date: Sun, 25 Jun 2017 15:01:08 +0200 Subject: [PATCH 133/200] [BEAM-2389] moved GcpCoreApiSurfaceTest to corresponding module, adapted exposed packagees --- .../extensions/gcp/GcpCoreApiSurfaceTest.java | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java index 625c24883331c..a0d9e4b668848 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java @@ -15,14 +15,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam; +package org.apache.beam.sdk.extensions.gcp; -import static org.apache.beam.sdk.util.ApiSurface.containsOnlyPackages; +import static org.apache.beam.sdk.util.ApiSurface.classesInPackage; +import static org.apache.beam.sdk.util.ApiSurface.containsOnlyClassesMatching; import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.collect.ImmutableSet; import java.util.Set; import org.apache.beam.sdk.util.ApiSurface; +import org.hamcrest.Matcher; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,28 +34,32 @@ public class GcpCoreApiSurfaceTest { @Test - public void testApiSurface() throws Exception { + public void testGcpCoreApiSurface() throws Exception { + final Package thisPackage = getClass().getPackage(); + final ClassLoader thisClassLoader = getClass().getClassLoader(); + final ApiSurface apiSurface = + ApiSurface.ofPackage(thisPackage, thisClassLoader) + .pruningPattern("org[.]apache[.]beam[.].*Test.*") + .pruningPattern("org[.]apache[.]beam[.].*IT") + .pruningPattern("java[.]lang.*") + .pruningPattern("java[.]util.*"); @SuppressWarnings("unchecked") - final Set allowed = + final Set>> allowedClasses = ImmutableSet.of( - "org.apache.beam", - "com.google.api.client", - "com.google.api.services.storage", - "com.google.auth", - "com.fasterxml.jackson.annotation", - "com.fasterxml.jackson.core", - "com.fasterxml.jackson.databind", - "org.apache.avro", - "org.hamcrest", - // via DataflowMatchers - "org.codehaus.jackson", - // via Avro - "org.joda.time", - "org.junit", - "sun.reflect"); + classesInPackage("com.google.api.client.googleapis"), + classesInPackage("com.google.api.client.http"), + classesInPackage("com.google.api.client.json"), + classesInPackage("com.google.api.client.util"), + classesInPackage("com.google.api.services.storage"), + classesInPackage("com.google.auth"), + classesInPackage("com.fasterxml.jackson.annotation"), + classesInPackage("java"), + classesInPackage("javax"), + classesInPackage("org.apache.beam.sdk"), + classesInPackage("org.joda.time") + ); - assertThat( - ApiSurface.getSdkApiSurface(getClass().getClassLoader()), containsOnlyPackages(allowed)); + assertThat(apiSurface, containsOnlyClassesMatching(allowedClasses)); } } From 6ade8426edc2ace1a9bec8f9501d8dad17e91365 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 28 Jun 2017 12:51:31 -0700 Subject: [PATCH 134/200] Add a Combine Test for Sliding Windows without Context --- .../beam/sdk/transforms/CombineTest.java | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java index e2469ab3ba6f1..b24d82dee4621 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java @@ -29,11 +29,13 @@ import static org.junit.Assert.assertThat; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -323,6 +325,67 @@ public void testFixedWindowsCombineWithContext() { pipeline.run(); } + @Test + @Category(ValidatesRunner.class) + public void 
testSlidingWindowsCombine() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1L)), + TimestampedValue.of("b", new Instant(2L)), + TimestampedValue.of("c", new Instant(3L)))) + .apply( + Window.into( + SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1L)))); + PCollection> combined = + input.apply( + Combine.globally( + new CombineFn, List>() { + @Override + public List createAccumulator() { + return new ArrayList<>(); + } + + @Override + public List addInput(List accumulator, String input) { + accumulator.add(input); + return accumulator; + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + // Mutate all of the accumulators. Instances should be used in only one + // place, and not + // reused after merging. + List cur = createAccumulator(); + for (List accumulator : accumulators) { + accumulator.addAll(cur); + cur = accumulator; + } + return cur; + } + + @Override + public List extractOutput(List accumulator) { + List result = new ArrayList<>(accumulator); + Collections.sort(result); + return result; + } + }) + .withoutDefaults()); + + PAssert.that(combined) + .containsInAnyOrder( + ImmutableList.of("a"), + ImmutableList.of("a", "b"), + ImmutableList.of("a", "b", "c"), + ImmutableList.of("b", "c"), + ImmutableList.of("c")); + + pipeline.run(); + } + @Test @Category(ValidatesRunner.class) public void testSlidingWindowsCombineWithContext() { From ed815be8f4999aad6b02ae16574d8dbe1edc1c36 Mon Sep 17 00:00:00 2001 From: Stephen Sisk Date: Wed, 28 Jun 2017 15:30:26 -0700 Subject: [PATCH 135/200] Upgrade beam bigtable client dependency to 0.9.7.1 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 069191ccbb95c..536a11c5bbfdd 100644 --- a/pom.xml +++ b/pom.xml @@ -108,7 +108,7 @@ 1.0.0-rc2 1.8.2 v2-rev295-1.22.0 - 0.9.6.2 + 0.9.7.1 v1-rev6-1.22.0 0.1.0 v2-rev8-1.22.0 From 90cc2bcfdf2256d09ed2bc155c65c5a011c42026 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 28 Jun 2017 16:34:47 -0700 Subject: [PATCH 136/200] Visit composite nodes when checking for picklability. --- sdks/python/apache_beam/pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 724c87d023f14..fe36d85a7a3b0 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -477,6 +477,9 @@ def _verify_runner_api_compatible(self): class Visitor(PipelineVisitor): # pylint: disable=used-before-assignment ok = True # Really a nonlocal. + def enter_composite_transform(self, transform_node): + self.visit_transform(transform_node) + def visit_transform(self, transform_node): if transform_node.side_inputs: # No side inputs (yet). 
@@ -555,7 +558,7 @@ def visit_value(self, value, producer_node): pass def visit_transform(self, transform_node): - """Callback for visiting a transform node in the pipeline DAG.""" + """Callback for visiting a transform leaf node in the pipeline DAG.""" pass def enter_composite_transform(self, transform_node): From 4f9820b1f24103831f3b0a4f5783f9ca726f8cd7 Mon Sep 17 00:00:00 2001 From: = <=> Date: Wed, 14 Jun 2017 20:11:49 -0400 Subject: [PATCH 137/200] Removed OnceTriggerStateMachine --- .../core/triggers/AfterAllStateMachine.java | 25 +++++++---------- ...fterDelayFromFirstElementStateMachine.java | 6 ++--- .../core/triggers/AfterFirstStateMachine.java | 20 +++++++------- .../core/triggers/AfterPaneStateMachine.java | 6 ++--- .../triggers/AfterWatermarkStateMachine.java | 7 ++--- .../ExecutableTriggerStateMachine.java | 23 +++------------- .../core/triggers/NeverStateMachine.java | 5 ++-- .../core/triggers/TriggerStateMachine.java | 27 ------------------- .../triggers/AfterFirstStateMachineTest.java | 5 ++-- .../AfterWatermarkStateMachineTest.java | 7 +++-- .../triggers/StubTriggerStateMachine.java | 7 +++-- 11 files changed, 44 insertions(+), 94 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterAllStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterAllStateMachine.java index 0f0c17ca41c04..3530ed1a34cda 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterAllStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterAllStateMachine.java @@ -23,7 +23,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.annotations.Experimental; /** @@ -31,7 +30,7 @@ * have fired. */ @Experimental(Experimental.Kind.TRIGGER) -public class AfterAllStateMachine extends OnceTriggerStateMachine { +public class AfterAllStateMachine extends TriggerStateMachine { private AfterAllStateMachine(List subTriggers) { super(subTriggers); @@ -42,11 +41,11 @@ private AfterAllStateMachine(List subTriggers) { * Returns an {@code AfterAll} {@code Trigger} with the given subtriggers. */ @SafeVarargs - public static OnceTriggerStateMachine of(TriggerStateMachine... triggers) { + public static TriggerStateMachine of(TriggerStateMachine... triggers) { return new AfterAllStateMachine(Arrays.asList(triggers)); } - public static OnceTriggerStateMachine of(Iterable triggers) { + public static TriggerStateMachine of(Iterable triggers) { return new AfterAllStateMachine(ImmutableList.copyOf(triggers)); } @@ -78,24 +77,21 @@ public void onMerge(OnMergeContext c) throws Exception { */ @Override public boolean shouldFire(TriggerContext context) throws Exception { - for (ExecutableTriggerStateMachine subtrigger : context.trigger().subTriggers()) { - if (!context.forTrigger(subtrigger).trigger().isFinished() - && !subtrigger.invokeShouldFire(context)) { + for (ExecutableTriggerStateMachine subTrigger : context.trigger().subTriggers()) { + if (!context.forTrigger(subTrigger).trigger().isFinished() + && !subTrigger.invokeShouldFire(context)) { return false; } } return true; } - /** - * Invokes {@link #onFire} for all subtriggers, eliding redundant calls to {@link #shouldFire} - * because they all must be ready to fire. 
- */ @Override - public void onOnlyFiring(TriggerContext context) throws Exception { - for (ExecutableTriggerStateMachine subtrigger : context.trigger().subTriggers()) { - subtrigger.invokeOnFire(context); + public void onFire(TriggerContext context) throws Exception { + for (ExecutableTriggerStateMachine subTrigger : context.trigger().subTriggers()) { + subTrigger.invokeOnFire(context); } + context.trigger().setFinished(true); } @Override @@ -103,7 +99,6 @@ public String toString() { StringBuilder builder = new StringBuilder("AfterAll.of("); Joiner.on(", ").appendTo(builder, subTriggers); builder.append(")"); - return builder.toString(); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterDelayFromFirstElementStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterDelayFromFirstElementStateMachine.java index 8d8d0de40bbe0..06c2066a1c313 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterDelayFromFirstElementStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterDelayFromFirstElementStateMachine.java @@ -27,7 +27,6 @@ import org.apache.beam.runners.core.StateMerging; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.state.CombiningState; @@ -50,7 +49,7 @@ // This class should be inlined to subclasses and deleted, simplifying them too // https://issues.apache.org/jira/browse/BEAM-1486 @Experimental(Experimental.Kind.TRIGGER) -public abstract class AfterDelayFromFirstElementStateMachine extends OnceTriggerStateMachine { +public abstract class AfterDelayFromFirstElementStateMachine extends TriggerStateMachine { protected static final List> IDENTITY = ImmutableList.>of(); @@ -237,8 +236,9 @@ && getCurrentTime(context) != null } @Override - protected void onOnlyFiring(TriggerStateMachine.TriggerContext context) throws Exception { + public final void onFire(TriggerContext context) throws Exception { clear(context); + context.trigger().setFinished(true); } protected Instant computeTargetTimestamp(Instant time) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachine.java index 840a65cfdd04f..58c24c5e82a85 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachine.java @@ -23,7 +23,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.annotations.Experimental; /** @@ -31,7 +30,7 @@ * sub-triggers have fired. */ @Experimental(Experimental.Kind.TRIGGER) -public class AfterFirstStateMachine extends OnceTriggerStateMachine { +public class AfterFirstStateMachine extends TriggerStateMachine { AfterFirstStateMachine(List subTriggers) { super(subTriggers); @@ -42,12 +41,12 @@ public class AfterFirstStateMachine extends OnceTriggerStateMachine { * Returns an {@code AfterFirst} {@code Trigger} with the given subtriggers. 
*/ @SafeVarargs - public static OnceTriggerStateMachine of( + public static TriggerStateMachine of( TriggerStateMachine... triggers) { return new AfterFirstStateMachine(Arrays.asList(triggers)); } - public static OnceTriggerStateMachine of( + public static TriggerStateMachine of( Iterable triggers) { return new AfterFirstStateMachine(ImmutableList.copyOf(triggers)); } @@ -79,18 +78,19 @@ public boolean shouldFire(TriggerStateMachine.TriggerContext context) throws Exc } @Override - protected void onOnlyFiring(TriggerContext context) throws Exception { - for (ExecutableTriggerStateMachine subtrigger : context.trigger().subTriggers()) { - TriggerContext subContext = context.forTrigger(subtrigger); - if (subtrigger.invokeShouldFire(subContext)) { + public void onFire(TriggerContext context) throws Exception { + for (ExecutableTriggerStateMachine subTrigger : context.trigger().subTriggers()) { + TriggerContext subContext = context.forTrigger(subTrigger); + if (subTrigger.invokeShouldFire(subContext)) { // If the trigger is ready to fire, then do whatever it needs to do. - subtrigger.invokeOnFire(subContext); + subTrigger.invokeOnFire(subContext); } else { // If the trigger is not ready to fire, it is nonetheless true that whatever // pending pane it was tracking is now gone. - subtrigger.invokeClear(subContext); + subTrigger.invokeClear(subContext); } } + context.trigger().setFinished(true); } @Override diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterPaneStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterPaneStateMachine.java index b9fbac34d0fba..1ce035a7d6c09 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterPaneStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterPaneStateMachine.java @@ -23,7 +23,6 @@ import org.apache.beam.runners.core.StateMerging; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.state.CombiningState; @@ -33,7 +32,7 @@ * {@link TriggerStateMachine}s that fire based on properties of the elements in the current pane. 
*/ @Experimental(Experimental.Kind.TRIGGER) -public class AfterPaneStateMachine extends OnceTriggerStateMachine { +public class AfterPaneStateMachine extends TriggerStateMachine { private static final StateTag> ELEMENTS_IN_PANE_TAG = @@ -130,7 +129,8 @@ public int hashCode() { } @Override - protected void onOnlyFiring(TriggerStateMachine.TriggerContext context) throws Exception { + public void onFire(TriggerStateMachine.TriggerContext context) throws Exception { clear(context); + context.trigger().setFinished(true); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachine.java index c9eee15b857fa..509c96b9995e8 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachine.java @@ -22,7 +22,6 @@ import com.google.common.collect.ImmutableList; import java.util.Objects; import javax.annotation.Nullable; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.state.TimeDomain; @@ -242,7 +241,7 @@ private void onLateFiring(TriggerStateMachine.TriggerContext context) throws Exc /** * A watermark trigger targeted relative to the end of the window. */ - public static class FromEndOfWindow extends OnceTriggerStateMachine { + public static class FromEndOfWindow extends TriggerStateMachine { private FromEndOfWindow() { super(null); @@ -319,6 +318,8 @@ private boolean endOfWindowReached(TriggerStateMachine.TriggerContext context) { } @Override - protected void onOnlyFiring(TriggerStateMachine.TriggerContext context) throws Exception { } + public void onFire(TriggerStateMachine.TriggerContext context) throws Exception { + context.trigger().setFinished(true); + } } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/ExecutableTriggerStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/ExecutableTriggerStateMachine.java index c4d89c287acdd..cdcff644fc30b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/ExecutableTriggerStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/ExecutableTriggerStateMachine.java @@ -23,7 +23,6 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; /** @@ -46,17 +45,14 @@ public static ExecutableTriggerStateMachine create( private static ExecutableTriggerStateMachine create( TriggerStateMachine trigger, int nextUnusedIndex) { - if (trigger instanceof OnceTriggerStateMachine) { - return new ExecutableOnceTriggerStateMachine( - (OnceTriggerStateMachine) trigger, nextUnusedIndex); - } else { + return new ExecutableTriggerStateMachine(trigger, nextUnusedIndex); - } + } public static ExecutableTriggerStateMachine createForOnceTrigger( - OnceTriggerStateMachine trigger, int nextUnusedIndex) { - return new ExecutableOnceTriggerStateMachine(trigger, nextUnusedIndex); + TriggerStateMachine trigger, int nextUnusedIndex) { + return new ExecutableTriggerStateMachine(trigger, nextUnusedIndex); } private 
ExecutableTriggerStateMachine(TriggerStateMachine trigger, int nextUnusedIndex) { @@ -146,15 +142,4 @@ public void invokeOnFire(TriggerStateMachine.TriggerContext c) throws Exception public void invokeClear(TriggerStateMachine.TriggerContext c) throws Exception { trigger.clear(c.forTrigger(this)); } - - /** - * {@link ExecutableTriggerStateMachine} that enforces the fact that the trigger should always - * FIRE_AND_FINISH and never just FIRE. - */ - private static class ExecutableOnceTriggerStateMachine extends ExecutableTriggerStateMachine { - - public ExecutableOnceTriggerStateMachine(OnceTriggerStateMachine trigger, int nextUnusedIndex) { - super(trigger, nextUnusedIndex); - } - } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/NeverStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/NeverStateMachine.java index f32c7a8d9d5d4..f8c5e8ba5eb01 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/NeverStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/NeverStateMachine.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.core.triggers; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -27,7 +26,7 @@ *

    Using this trigger will only produce output when the watermark passes the end of the * {@link BoundedWindow window} plus the allowed lateness. */ -public final class NeverStateMachine extends OnceTriggerStateMachine { +public final class NeverStateMachine extends TriggerStateMachine { /** * Returns a trigger which never fires. Output will be produced from the using {@link GroupByKey} * when the {@link BoundedWindow} closes. @@ -53,7 +52,7 @@ public boolean shouldFire(TriggerStateMachine.TriggerContext context) { } @Override - protected void onOnlyFiring(TriggerStateMachine.TriggerContext context) { + public void onFire(TriggerStateMachine.TriggerContext context) { throw new UnsupportedOperationException( String.format("%s should never fire", getClass().getSimpleName())); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachine.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachine.java index 6a2cf0c91a201..880aa48cba5a1 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachine.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachine.java @@ -453,35 +453,8 @@ public int hashCode() { * } * * - *
    Note that if {@code t1} is {@link OnceTriggerStateMachine}, then {@code t1.orFinally(t2)} is - * the same as {@code AfterFirst.of(t1, t2)}. */ public TriggerStateMachine orFinally(TriggerStateMachine until) { return new OrFinallyStateMachine(this, until); } - - /** - * {@link TriggerStateMachine}s that are guaranteed to fire at most once should extend from this, - * rather than the general {@link TriggerStateMachine} class to indicate that behavior. - */ - public abstract static class OnceTriggerStateMachine extends TriggerStateMachine { - protected OnceTriggerStateMachine(List subTriggers) { - super(subTriggers); - } - - /** - * {@inheritDoc} - */ - @Override - public final void onFire(TriggerContext context) throws Exception { - onOnlyFiring(context); - context.trigger().setFinished(true); - } - - /** - * Called exactly once by {@link #onFire} when the trigger is fired. By default, - * invokes {@link #onFire} on all subtriggers for which {@link #shouldFire} is {@code true}. - */ - protected abstract void onOnlyFiring(TriggerContext context) throws Exception; - } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachineTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachineTest.java index 453c8ff0a63d1..2be90de47c0a2 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachineTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterFirstStateMachineTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.when; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachineTester.SimpleTriggerStateMachineTester; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -42,8 +41,8 @@ @RunWith(JUnit4.class) public class AfterFirstStateMachineTest { - @Mock private OnceTriggerStateMachine mockTrigger1; - @Mock private OnceTriggerStateMachine mockTrigger2; + @Mock private TriggerStateMachine mockTrigger1; + @Mock private TriggerStateMachine mockTrigger2; private SimpleTriggerStateMachineTester tester; private static TriggerStateMachine.TriggerContext anyTriggerContext() { return Mockito.any(); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java index e4d10a0a0fbad..45a5cfb9e6508 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java @@ -25,7 +25,6 @@ import static org.mockito.Mockito.when; import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnMergeContext; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachineTester.SimpleTriggerStateMachineTester; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -46,8 +45,8 @@ @RunWith(JUnit4.class) public class AfterWatermarkStateMachineTest { - @Mock private OnceTriggerStateMachine mockEarly; - @Mock private OnceTriggerStateMachine 
mockLate; + @Mock private TriggerStateMachine mockEarly; + @Mock private TriggerStateMachine mockLate; private SimpleTriggerStateMachineTester tester; private static TriggerStateMachine.TriggerContext anyTriggerContext() { @@ -70,7 +69,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - public void testRunningAsTrigger(OnceTriggerStateMachine mockTrigger, IntervalWindow window) + public void testRunningAsTrigger(TriggerStateMachine mockTrigger, IntervalWindow window) throws Exception { // Don't fire due to mock saying no diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java index 4512848aaa4e8..1bc757ec04132 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java @@ -18,12 +18,11 @@ package org.apache.beam.runners.core.triggers; import com.google.common.collect.Lists; -import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; /** - * No-op {@link OnceTriggerStateMachine} implementation for testing. + * No-op {@link TriggerStateMachine} implementation for testing. */ -abstract class StubTriggerStateMachine extends OnceTriggerStateMachine { +abstract class StubTriggerStateMachine extends TriggerStateMachine { /** * Create a stub {@link TriggerStateMachine} instance which returns the specified name on {@link * #toString()}. @@ -42,7 +41,7 @@ protected StubTriggerStateMachine() { } @Override - protected void onOnlyFiring(TriggerContext context) throws Exception { + public void onFire(TriggerContext context) throws Exception { } @Override From 38dd12df6dee2ada31ad9c52f8d9dc99225f1bc2 Mon Sep 17 00:00:00 2001 From: Pei He Date: Tue, 20 Jun 2017 16:09:26 -0700 Subject: [PATCH 138/200] WindowingStrategy: add OnTimeBehavior to control whether to emit empty ON_TIME pane. 
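A minimal usage sketch of the new knob from the pipeline author's side, for context. It is illustrative only: the input PCollection, element type, trigger, and durations are invented for the example and the usual windowing imports are assumed; the only API this change actually adds is Window.withOnTimeBehavior(OnTimeBehavior) and the OnTimeBehavior enum.

    // Suppress the empty ON_TIME pane that would otherwise fire when the
    // watermark passes the end of the window with no new data since the
    // last early firing.
    PCollection<KV<String, Long>> windowed =
        input.apply(
            Window.<KV<String, Long>>into(FixedWindows.of(Duration.standardMinutes(1)))
                .triggering(
                    AfterWatermark.pastEndOfWindow()
                        .withEarlyFirings(
                            AfterProcessingTime.pastFirstElementInPane()
                                .plusDelayOf(Duration.standardSeconds(30))))
                .withAllowedLateness(Duration.standardMinutes(10))
                .withOnTimeBehavior(Window.OnTimeBehavior.FIRE_IF_NON_EMPTY)
                .discardingFiredPanes());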
--- .../WindowingStrategyTranslation.java | 26 ++- .../beam/runners/core/ReduceFnRunner.java | 6 +- .../beam/runners/core/ReduceFnRunnerTest.java | 161 ++++++++++++++++++ .../src/main/proto/beam_runner_api.proto | 14 ++ .../beam/sdk/transforms/windowing/Window.java | 32 ++++ .../beam/sdk/values/WindowingStrategy.java | 46 ++++- 6 files changed, 273 insertions(+), 12 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java index 718efe7c9190a..88ebc01b1df8d 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.Trigger; import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior; +import org.apache.beam.sdk.transforms.windowing.Window.OnTimeBehavior; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.WindowingStrategy; @@ -119,6 +120,27 @@ public static ClosingBehavior fromProto(RunnerApi.ClosingBehavior proto) { } } + + public static OnTimeBehavior fromProto(RunnerApi.OnTimeBehavior proto) { + switch (proto) { + case FIRE_ALWAYS: + return OnTimeBehavior.FIRE_ALWAYS; + case FIRE_IF_NONEMPTY: + return OnTimeBehavior.FIRE_IF_NON_EMPTY; + case UNRECOGNIZED: + default: + // Whether or not it is proto that cannot recognize it (due to the version of the + // generated code we link to) or the switch hasn't been updated to handle it, + // the situation is the same: we don't know what this OutputTime means + throw new IllegalArgumentException( + String.format( + "Cannot convert unknown %s to %s: %s", + RunnerApi.OnTimeBehavior.class.getCanonicalName(), + OnTimeBehavior.class.getCanonicalName(), + proto)); + } + } + public static RunnerApi.OutputTime toProto(TimestampCombiner timestampCombiner) { switch(timestampCombiner) { case EARLIEST: @@ -323,13 +345,15 @@ public static RunnerApi.WindowingStrategy toProto( Trigger trigger = TriggerTranslation.fromProto(proto.getTrigger()); ClosingBehavior closingBehavior = fromProto(proto.getClosingBehavior()); Duration allowedLateness = Duration.millis(proto.getAllowedLateness()); + OnTimeBehavior onTimeBehavior = fromProto(proto.getOnTimeBehavior()); return WindowingStrategy.of(windowFn) .withAllowedLateness(allowedLateness) .withMode(accumulationMode) .withTrigger(trigger) .withTimestampCombiner(timestampCombiner) - .withClosingBehavior(closingBehavior); + .withClosingBehavior(closingBehavior) + .withOnTimeBehavior(onTimeBehavior); } public static WindowFn windowFnFromProto(SdkFunctionSpec windowFnSpec) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java index 75b6acda3312a..a33bac1f61361 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java @@ -51,6 +51,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import 
org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowTracing; @@ -920,8 +921,9 @@ private boolean needToEmit(boolean isEmpty, boolean isFinished, PaneInfo.Timing // The pane has elements. return true; } - if (timing == Timing.ON_TIME) { - // This is the unique ON_TIME pane. + if (timing == Timing.ON_TIME + && windowingStrategy.getOnTimeBehavior() == Window.OnTimeBehavior.FIRE_ALWAYS) { + // This is an empty ON_TIME pane. return true; } if (isFinished && windowingStrategy.getClosingBehavior() == ClosingBehavior.FIRE_ALWAYS) { diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java index 4f68038f38a7f..3a2c2205c5ea6 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java @@ -67,6 +67,7 @@ import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.Trigger; +import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; @@ -1422,6 +1423,166 @@ public void testEmptyOnTimeFromOrFinally() throws Exception { WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.ON_TIME, 1, 0))); } + /** + * Test that it won't fire an empty on-time pane when OnTimeBehavior is FIRE_IF_NON_EMPTY. 
+ */ + @Test + public void testEmptyOnTimeWithOnTimeBehaviorFireIfNonEmpty() throws Exception { + + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(10))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withTrigger( + AfterEach.inOrder( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(new Duration(5))) + .orFinally(AfterWatermark.pastEndOfWindow()), + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(new Duration(25))))) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)) + .withClosingBehavior(ClosingBehavior.FIRE_ALWAYS) + .withOnTimeBehavior(Window.OnTimeBehavior.FIRE_IF_NON_EMPTY); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceInputWatermark(new Instant(0)); + tester.advanceProcessingTime(new Instant(0)); + + // Processing time timer for 5 + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(1, new Instant(3)), + TimestampedValue.of(1, new Instant(7)), + TimestampedValue.of(1, new Instant(5))); + + // Should fire early pane + tester.advanceProcessingTime(new Instant(6)); + + // Should not fire empty on time pane + tester.advanceInputWatermark(new Instant(11)); + + // Should fire final GC pane + tester.advanceInputWatermark(new Instant(10 + 100)); + List> output = tester.extractOutput(); + assertEquals(2, output.size()); + + assertThat(output.get(0), WindowMatchers.isSingleWindowedValue(4, 1, 0, 10)); + assertThat(output.get(1), WindowMatchers.isSingleWindowedValue(4, 9, 0, 10)); + + assertThat( + output.get(0), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1))); + assertThat( + output.get(1), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.LATE, 1, 0))); + } + + /** + * Test that it fires an empty on-time isFinished pane when OnTimeBehavior is FIRE_ALWAYS + * and ClosingBehavior is FIRE_IF_NON_EMPTY. + * + *
    This is a test just for backward compatibility. + */ + @Test + public void testEmptyOnTimeWithOnTimeBehaviorBackwardCompatibility() throws Exception { + + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(10))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withTrigger(AfterWatermark.pastEndOfWindow() + .withEarlyFirings(AfterPane.elementCountAtLeast(1))) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(0)) + .withClosingBehavior(ClosingBehavior.FIRE_IF_NON_EMPTY); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceInputWatermark(new Instant(0)); + tester.advanceProcessingTime(new Instant(0)); + + tester.injectElements( + TimestampedValue.of(1, new Instant(1))); + + // Should fire empty on time isFinished pane + tester.advanceInputWatermark(new Instant(11)); + + List> output = tester.extractOutput(); + assertEquals(2, output.size()); + + assertThat( + output.get(0), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1))); + assertThat( + output.get(1), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.ON_TIME, 1, 0))); + } + + /** + * Test that it won't fire an empty on-time pane when OnTimeBehavior is FIRE_IF_NON_EMPTY + * and when receiving late data. + */ + @Test + public void testEmptyOnTimeWithOnTimeBehaviorFireIfNonEmptyAndLateData() throws Exception { + + WindowingStrategy strategy = + WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(10))) + .withTimestampCombiner(TimestampCombiner.EARLIEST) + .withTrigger( + AfterEach.inOrder( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(new Duration(5))) + .orFinally(AfterWatermark.pastEndOfWindow()), + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(new Duration(25))))) + .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)) + .withOnTimeBehavior(Window.OnTimeBehavior.FIRE_IF_NON_EMPTY); + + ReduceFnTester tester = + ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of()); + + tester.advanceInputWatermark(new Instant(0)); + tester.advanceProcessingTime(new Instant(0)); + + // Processing time timer for 5 + tester.injectElements( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(1, new Instant(3)), + TimestampedValue.of(1, new Instant(7)), + TimestampedValue.of(1, new Instant(5))); + + // Should fire early pane + tester.advanceProcessingTime(new Instant(6)); + + // Should not fire empty on time pane + tester.advanceInputWatermark(new Instant(11)); + + // Processing late data, and should fire late pane + tester.injectElements( + TimestampedValue.of(1, new Instant(9))); + tester.advanceProcessingTime(new Instant(6 + 25 + 1)); + + List> output = tester.extractOutput(); + assertEquals(2, output.size()); + + assertThat(output.get(0), WindowMatchers.isSingleWindowedValue(4, 1, 0, 10)); + assertThat(output.get(1), WindowMatchers.isSingleWindowedValue(5, 9, 0, 10)); + + assertThat( + output.get(0), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.EARLY, 0, -1))); + assertThat( + output.get(1), + WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.LATE, 1, 0))); + } + /** * Tests for processing time firings after the watermark passes the end of the window. 
* Specifically, verify the proper triggerings and pane-info of a typical speculative/on-time/late diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto index 039ecb0ba0407..24e907a72dad3 100644 --- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto +++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto @@ -433,6 +433,9 @@ message WindowingStrategy { // (Required) The duration, in milliseconds, beyond the end of a window at // which the window becomes droppable. int64 allowed_lateness = 8; + + // (Required) Indicate whether empty on-time panes should be omitted. + OnTimeBehavior OnTimeBehavior = 9; } // Whether or not a PCollection's WindowFn is non-merging, merging, or @@ -478,6 +481,17 @@ enum ClosingBehavior { EMIT_IF_NONEMPTY = 1; } +// Controls whether or not an aggregating transform should output data +// when an on-time pane is empty. +enum OnTimeBehavior { + // Always fire the on-time pane. Even if there is no new data since + // the previous firing, an element will be produced. + FIRE_ALWAYS = 0; + + // Only fire the on-time pane if there is new data since the previous firing. + FIRE_IF_NONEMPTY = 1; +} + // When a number of windowed, timestamped inputs are aggregated, the timestamp // for the resulting output. enum OutputTime { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java index 105ebfbe24afa..a12be6d3b4bf0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java @@ -162,6 +162,24 @@ public enum ClosingBehavior { } + /** + * Specifies the conditions under which an on-time pane will be created when a window is closed. + */ + public enum OnTimeBehavior { + /** + * Always fire the on-time pane. Even if there is no new data since the previous firing, + * an element will be produced. + * + *
    This is the default behavior. + */ + FIRE_ALWAYS, + /** + * Only fire the on-time pane if there is new data since the previous firing. + */ + FIRE_IF_NON_EMPTY + + } + /** * Creates a {@code Window} {@code PTransform} that uses the given * {@link WindowFn} to window the data. @@ -195,6 +213,7 @@ public static Window configure() { @Nullable abstract AccumulationMode getAccumulationMode(); @Nullable abstract Duration getAllowedLateness(); @Nullable abstract ClosingBehavior getClosingBehavior(); + @Nullable abstract OnTimeBehavior getOnTimeBehavior(); @Nullable abstract TimestampCombiner getTimestampCombiner(); abstract Builder toBuilder(); @@ -206,6 +225,7 @@ abstract static class Builder { abstract Builder setAccumulationMode(AccumulationMode mode); abstract Builder setAllowedLateness(Duration allowedLateness); abstract Builder setClosingBehavior(ClosingBehavior closingBehavior); + abstract Builder setOnTimeBehavior(OnTimeBehavior onTimeBehavior); abstract Builder setTimestampCombiner(TimestampCombiner timestampCombiner); abstract Window build(); @@ -298,6 +318,15 @@ public Window withAllowedLateness(Duration allowedLateness, ClosingBehavior b return toBuilder().setAllowedLateness(allowedLateness).setClosingBehavior(behavior).build(); } + /** + * (Experimental) Override the default {@link OnTimeBehavior}, to control + * whether to output an empty on-time pane. + */ + @Experimental(Kind.TRIGGER) + public Window withOnTimeBehavior(OnTimeBehavior behavior) { + return toBuilder().setOnTimeBehavior(behavior).build(); + } + /** * Get the output strategy of this {@link Window Window PTransform}. For internal use * only. @@ -321,6 +350,9 @@ public Window withAllowedLateness(Duration allowedLateness, ClosingBehavior b if (getClosingBehavior() != null) { result = result.withClosingBehavior(getClosingBehavior()); } + if (getOnTimeBehavior() != null) { + result = result.withOnTimeBehavior(getOnTimeBehavior()); + } if (getTimestampCombiner() != null) { result = result.withTimestampCombiner(getTimestampCombiner()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/WindowingStrategy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/WindowingStrategy.java index 8a773e23f9b5e..3b74e699cc13f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/WindowingStrategy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/WindowingStrategy.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.transforms.windowing.Trigger; import org.apache.beam.sdk.transforms.windowing.Window.ClosingBehavior; +import org.apache.beam.sdk.transforms.windowing.Window.OnTimeBehavior; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.joda.time.Duration; @@ -59,6 +60,7 @@ public enum AccumulationMode { private final AccumulationMode mode; private final Duration allowedLateness; private final ClosingBehavior closingBehavior; + private final OnTimeBehavior onTimeBehavior; private final TimestampCombiner timestampCombiner; private final boolean triggerSpecified; private final boolean modeSpecified; @@ -71,7 +73,8 @@ private WindowingStrategy( AccumulationMode mode, boolean modeSpecified, Duration allowedLateness, boolean allowedLatenessSpecified, TimestampCombiner timestampCombiner, boolean timestampCombinerSpecified, - ClosingBehavior closingBehavior) { + ClosingBehavior closingBehavior, + OnTimeBehavior onTimeBehavior) { this.windowFn = windowFn; this.trigger = trigger; this.triggerSpecified = 
triggerSpecified; @@ -80,6 +83,7 @@ private WindowingStrategy( this.allowedLateness = allowedLateness; this.allowedLatenessSpecified = allowedLatenessSpecified; this.closingBehavior = closingBehavior; + this.onTimeBehavior = onTimeBehavior; this.timestampCombiner = timestampCombiner; this.timestampCombinerSpecified = timestampCombinerSpecified; } @@ -98,7 +102,8 @@ public static WindowingStrategy of( AccumulationMode.DISCARDING_FIRED_PANES, false, DEFAULT_ALLOWED_LATENESS, false, TimestampCombiner.END_OF_WINDOW, false, - ClosingBehavior.FIRE_IF_NON_EMPTY); + ClosingBehavior.FIRE_IF_NON_EMPTY, + OnTimeBehavior.FIRE_ALWAYS); } public WindowFn getWindowFn() { @@ -133,6 +138,10 @@ public ClosingBehavior getClosingBehavior() { return closingBehavior; } + public OnTimeBehavior getOnTimeBehavior() { + return onTimeBehavior; + } + public TimestampCombiner getTimestampCombiner() { return timestampCombiner; } @@ -152,7 +161,8 @@ public WindowingStrategy withTrigger(Trigger trigger) { mode, modeSpecified, allowedLateness, allowedLatenessSpecified, timestampCombiner, timestampCombinerSpecified, - closingBehavior); + closingBehavior, + onTimeBehavior); } /** @@ -166,7 +176,8 @@ public WindowingStrategy withMode(AccumulationMode mode) { mode, true, allowedLateness, allowedLatenessSpecified, timestampCombiner, timestampCombinerSpecified, - closingBehavior); + closingBehavior, + onTimeBehavior); } /** @@ -183,7 +194,8 @@ public WindowingStrategy withWindowFn(WindowFn wildcardWindowFn) { mode, modeSpecified, allowedLateness, allowedLatenessSpecified, timestampCombiner, timestampCombinerSpecified, - closingBehavior); + closingBehavior, + onTimeBehavior); } /** @@ -197,7 +209,8 @@ public WindowingStrategy withAllowedLateness(Duration allowedLateness) { mode, modeSpecified, allowedLateness, true, timestampCombiner, timestampCombinerSpecified, - closingBehavior); + closingBehavior, + onTimeBehavior); } public WindowingStrategy withClosingBehavior(ClosingBehavior closingBehavior) { @@ -207,7 +220,19 @@ public WindowingStrategy withClosingBehavior(ClosingBehavior closingBehavi mode, modeSpecified, allowedLateness, allowedLatenessSpecified, timestampCombiner, timestampCombinerSpecified, - closingBehavior); + closingBehavior, + onTimeBehavior); + } + + public WindowingStrategy withOnTimeBehavior(OnTimeBehavior onTimeBehavior) { + return new WindowingStrategy( + windowFn, + trigger, triggerSpecified, + mode, modeSpecified, + allowedLateness, allowedLatenessSpecified, + timestampCombiner, timestampCombinerSpecified, + closingBehavior, + onTimeBehavior); } @Experimental(Experimental.Kind.OUTPUT_TIME) @@ -219,7 +244,8 @@ public WindowingStrategy withTimestampCombiner(TimestampCombiner timestamp mode, modeSpecified, allowedLateness, allowedLatenessSpecified, timestampCombiner, true, - closingBehavior); + closingBehavior, + onTimeBehavior); } @Override @@ -246,6 +272,7 @@ && isTimestampCombinerSpecified() == other.isTimestampCombinerSpecified() && getMode().equals(other.getMode()) && getAllowedLateness().equals(other.getAllowedLateness()) && getClosingBehavior().equals(other.getClosingBehavior()) + && getOnTimeBehavior().equals(other.getOnTimeBehavior()) && getTrigger().equals(other.getTrigger()) && getTimestampCombiner().equals(other.getTimestampCombiner()) && getWindowFn().equals(other.getWindowFn()); @@ -278,6 +305,7 @@ public WindowingStrategy fixDefaults() { mode, true, allowedLateness, true, timestampCombiner, true, - closingBehavior); + closingBehavior, + onTimeBehavior); } } From 
2efb0d561fc62ba44bf71db6937a54708944f0f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Author=3A=20=E6=B3=A2=E7=89=B9?= Date: Fri, 26 May 2017 17:46:55 +0800 Subject: [PATCH 139/200] ReduceFnRunner.onTrigger: add short circuit for empty pane, and move inputWM and pane after the short circuit. --- .../apache/beam/runners/core/ReduceFnRunner.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java index a33bac1f61361..ef33befffc939 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java @@ -953,11 +953,6 @@ private Instant onTrigger( ReduceFn.Context renamedContext, final boolean isFinished, boolean isEndOfWindow) throws Exception { - Instant inputWM = timerInternals.currentInputWatermarkTime(); - - // Calculate the pane info. - final PaneInfo pane = paneInfoTracker.getNextPaneInfo(directContext, isFinished).read(); - // Extract the window hold, and as a side effect clear it. final WatermarkHold.OldAndNewHolds pair = watermarkHold.extractAndRelease(renamedContext, isFinished).read(); @@ -966,7 +961,13 @@ private Instant onTrigger( @Nullable Instant newHold = pair.newHold; final boolean isEmpty = nonEmptyPanes.isEmpty(renamedContext.state()).read(); + if (isEmpty + && windowingStrategy.getClosingBehavior() == ClosingBehavior.FIRE_IF_NON_EMPTY + && windowingStrategy.getOnTimeBehavior() == Window.OnTimeBehavior.FIRE_IF_NON_EMPTY) { + return newHold; + } + Instant inputWM = timerInternals.currentInputWatermarkTime(); if (newHold != null) { // We can't be finished yet. checkState( @@ -998,6 +999,9 @@ private Instant onTrigger( } } + // Calculate the pane info. + final PaneInfo pane = paneInfoTracker.getNextPaneInfo(directContext, isFinished).read(); + // Only emit a pane if it has data or empty panes are observable. if (needToEmit(isEmpty, isFinished, pane.getTiming())) { // Run reduceFn.onTrigger method. 
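Taken together with the previous commit, the rules for emitting an empty pane can be condensed as below. This is a restatement for clarity, not the actual ReduceFnRunner source; the names mirror those in the diffs above. The short circuit added here simply returns before the pane-info and watermark reads whenever an empty pane could not be emitted under either behavior.

    // Condensed view of the emission decision for a pane.
    boolean needToEmit(
        boolean isEmpty, boolean isFinished, PaneInfo.Timing timing,
        WindowingStrategy<?, ?> strategy) {
      if (!isEmpty) {
        return true;  // the pane has elements
      }
      if (timing == PaneInfo.Timing.ON_TIME
          && strategy.getOnTimeBehavior() == Window.OnTimeBehavior.FIRE_ALWAYS) {
        return true;  // an ON_TIME firing is guaranteed
      }
      if (isFinished
          && strategy.getClosingBehavior() == Window.ClosingBehavior.FIRE_ALWAYS) {
        return true;  // a final firing is guaranteed
      }
      return false;   // empty, and both behaviors are FIRE_IF_NON_EMPTY
    }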
From dbeba09b0def871066753d96a2fb354bffe18c04 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 28 Jun 2017 07:12:51 -0700 Subject: [PATCH 140/200] Only use ASCII 'a' through 'z' for temporary Spanner tables --- .../apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java | 8 +++++--- .../apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java index f5d7cbd6c31d3..ca43b40104e88 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.text.RandomStringGenerator; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -161,9 +162,10 @@ public void tearDown() throws Exception { private String generateDatabaseName() { String random = - RandomStringUtils.randomAlphanumeric( - MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()) - .toLowerCase(); + new RandomStringGenerator.Builder() + .withinRange('a', 'z') + .build() + .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); return options.getDatabaseIdPrefix() + "-" + random; } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index 33532c929bab0..613756cff2481 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -33,7 +33,6 @@ import com.google.cloud.spanner.Statement; import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; import java.util.Collections; - import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -43,7 +42,6 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.commons.text.RandomStringGenerator; - import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -116,9 +114,11 @@ public void setUp() throws Exception { } private String generateDatabaseName() { - String random = new RandomStringGenerator.Builder().build() - .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()) - .toLowerCase(); + String random = + new RandomStringGenerator.Builder() + .withinRange('a', 'z') + .build() + .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); return options.getDatabaseIdPrefix() + "-" + random; } From fb7ec28cfb1291b04e0eac738054eefe0bb9a103 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Mon, 26 Jun 2017 18:03:53 -0700 Subject: [PATCH 141/200] Add PubSub I/O support to Python DirectRunner --- .../examples/streaming_wordcount.py | 12 ++- sdks/python/apache_beam/io/gcp/pubsub.py | 91 ++++++++++++++----- sdks/python/apache_beam/io/gcp/pubsub_test.py | 89 +++++++++++------- 
.../runners/dataflow/dataflow_runner.py | 11 ++- .../runners/direct/direct_runner.py | 54 +++++++++++ .../runners/direct/transform_evaluator.py | 89 ++++++++++++++++++ 6 files changed, 281 insertions(+), 65 deletions(-) diff --git a/sdks/python/apache_beam/examples/streaming_wordcount.py b/sdks/python/apache_beam/examples/streaming_wordcount.py index 4c29f2b46b302..7696d77893237 100644 --- a/sdks/python/apache_beam/examples/streaming_wordcount.py +++ b/sdks/python/apache_beam/examples/streaming_wordcount.py @@ -28,6 +28,8 @@ import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import StandardOptions import apache_beam.transforms.window as window @@ -41,13 +43,17 @@ def run(argv=None): parser = argparse.ArgumentParser() parser.add_argument( '--input_topic', required=True, - help='Input PubSub topic of the form "/topics//".') + help=('Input PubSub topic of the form ' + '"projects//topics/".')) parser.add_argument( '--output_topic', required=True, - help='Output PubSub topic of the form "/topics//".') + help=('Output PubSub topic of the form ' + '"projects//topic/".')) known_args, pipeline_args = parser.parse_known_args(argv) + options = PipelineOptions(pipeline_args) + options.view_as(StandardOptions).streaming = True - with beam.Pipeline(argv=pipeline_args) as p: + with beam.Pipeline(options=options) as p: # Read from PubSub into a PCollection. lines = p | beam.io.ReadStringsFromPubSub(known_args.input_topic) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index fabe29612a8b6..32d388a9e50f6 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -24,12 +24,16 @@ from __future__ import absolute_import +import re + from apache_beam import coders from apache_beam.io.iobase import Read from apache_beam.io.iobase import Write from apache_beam.runners.dataflow.native_io import iobase as dataflow_io +from apache_beam.transforms import core from apache_beam.transforms import PTransform from apache_beam.transforms import Map +from apache_beam.transforms import window from apache_beam.transforms.display import DisplayDataItem @@ -43,11 +47,12 @@ def __init__(self, topic=None, subscription=None, id_label=None): """Initializes ``ReadStringsFromPubSub``. Attributes: - topic: Cloud Pub/Sub topic in the form "/topics//". If - provided then subscription must be None. + topic: Cloud Pub/Sub topic in the form "projects//topics/ + ". If provided, subscription must be None. subscription: Existing Cloud Pub/Sub subscription to use in the - form "projects//subscriptions/". If provided then - topic must be None. + form "projects//subscriptions/". If not + specified, a temporary subscription will be created from the specified + topic. If provided, topic must be None. id_label: The attribute on incoming Pub/Sub messages to use as a unique record identifier. When specified, the value of this attribute (which can be any string that uniquely identifies the record) will be used for @@ -56,17 +61,14 @@ def __init__(self, topic=None, subscription=None, id_label=None): case, deduplication of the stream will be strictly best effort. 
""" super(ReadStringsFromPubSub, self).__init__() - if topic and subscription: - raise ValueError("Only one of topic or subscription should be provided.") - - if not (topic or subscription): - raise ValueError("Either a topic or subscription must be provided.") - self._source = _PubSubPayloadSource( topic, subscription=subscription, id_label=id_label) + def get_windowing(self, unused_inputs): + return core.Windowing(window.GlobalWindows()) + def expand(self, pvalue): pcoll = pvalue.pipeline | Read(self._source) pcoll.element_type = bytes @@ -93,15 +95,45 @@ def expand(self, pcoll): return pcoll | Write(self._sink) +PROJECT_ID_REGEXP = '[a-z][-a-z0-9:.]{4,61}[a-z0-9]' +SUBSCRIPTION_REGEXP = 'projects/([^/]+)/subscriptions/(.+)' +TOPIC_REGEXP = 'projects/([^/]+)/topics/(.+)' + + +def parse_topic(full_topic): + match = re.match(TOPIC_REGEXP, full_topic) + if not match: + raise ValueError( + 'PubSub topic must be in the form "projects//topics' + '/" (got %r).' % full_topic) + project, topic_name = match.group(1), match.group(2) + if not re.match(PROJECT_ID_REGEXP, project): + raise ValueError('Invalid PubSub project name: %r.' % project) + return project, topic_name + + +def parse_subscription(full_subscription): + match = re.match(SUBSCRIPTION_REGEXP, full_subscription) + if not match: + raise ValueError( + 'PubSub subscription must be in the form "projects/' + '/subscriptions/" (got %r).' % full_subscription) + project, subscription_name = match.group(1), match.group(2) + if not re.match(PROJECT_ID_REGEXP, project): + raise ValueError('Invalid PubSub project name: %r.' % project) + return project, subscription_name + + class _PubSubPayloadSource(dataflow_io.NativeSource): """Source for the payload of a message as bytes from a Cloud Pub/Sub topic. Attributes: - topic: Cloud Pub/Sub topic in the form "/topics//". If - provided then topic must be None. + topic: Cloud Pub/Sub topic in the form "projects//topics/". + If provided, subscription must be None. subscription: Existing Cloud Pub/Sub subscription to use in the - form "projects//subscriptions/". If provided then - subscription must be None. + form "projects//subscriptions/". If not specified, + a temporary subscription will be created from the specified topic. If + provided, topic must be None. id_label: The attribute on incoming Pub/Sub messages to use as a unique record identifier. When specified, the value of this attribute (which can be any string that uniquely identifies the record) will be used for @@ -111,13 +143,26 @@ class _PubSubPayloadSource(dataflow_io.NativeSource): """ def __init__(self, topic=None, subscription=None, id_label=None): - # we are using this coder explicitly for portability reasons of PubsubIO + # We are using this coder explicitly for portability reasons of PubsubIO # across implementations in languages. self.coder = coders.BytesCoder() - self.topic = topic - self.subscription = subscription + self.full_topic = topic + self.full_subscription = subscription + self.topic_name = None + self.subscription_name = None self.id_label = id_label + # Perform some validation on the topic and subscription. 
+ if not (topic or subscription): + raise ValueError('Either a topic or subscription must be provided.') + if topic and subscription: + raise ValueError('Only one of topic or subscription should be provided.') + + if topic: + self.project, self.topic_name = parse_topic(topic) + if subscription: + self.project, self.subscription_name = parse_subscription(subscription) + @property def format(self): """Source format name required for remote execution.""" @@ -128,10 +173,10 @@ def display_data(self): DisplayDataItem(self.id_label, label='ID Label Attribute').drop_if_none(), 'topic': - DisplayDataItem(self.topic, - label='Pubsub Topic'), + DisplayDataItem(self.full_topic, + label='Pubsub Topic').drop_if_none(), 'subscription': - DisplayDataItem(self.subscription, + DisplayDataItem(self.full_subscription, label='Pubsub Subscription').drop_if_none()} def reader(self): @@ -146,7 +191,9 @@ def __init__(self, topic): # we are using this coder explicitly for portability reasons of PubsubIO # across implementations in languages. self.coder = coders.BytesCoder() - self.topic = topic + self.full_topic = topic + + self.project, self.topic_name = parse_topic(topic) @property def format(self): @@ -154,7 +201,7 @@ def format(self): return 'pubsub' def display_data(self): - return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')} + return {'topic': DisplayDataItem(self.full_topic, label='Pubsub Topic')} def writer(self): raise NotImplementedError( diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 5d3e985597c0b..0dcc3c39ab5f5 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -31,89 +31,108 @@ from apache_beam.transforms.display_test import DisplayDataItemMatcher +# Protect against environments where the PubSub library is not available. 
+# pylint: disable=wrong-import-order, wrong-import-position +try: + from google.cloud import pubsub +except ImportError: + pubsub = None +# pylint: enable=wrong-import-order, wrong-import-position + + +@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestReadStringsFromPubSub(unittest.TestCase): def test_expand_with_topic(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub('a_topic', None, 'a_label') + pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', + None, 'a_label') # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) - # Ensure that the type on the intermediate read output PCollection is bytes - read_pcoll = pcoll.producer.inputs[0] - self.assertEqual(bytes, read_pcoll.element_type) - # Ensure that the properties passed through correctly - source = read_pcoll.producer.transform.source - self.assertEqual('a_topic', source.topic) + source = pcoll.producer.transform._source + self.assertEqual('a_topic', source.topic_name) self.assertEqual('a_label', source.id_label) def test_expand_with_subscription(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub(None, 'a_subscription', 'a_label') + pcoll = p | ReadStringsFromPubSub( + None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label') # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) - # Ensure that the type on the intermediate read output PCollection is bytes - read_pcoll = pcoll.producer.inputs[0] - self.assertEqual(bytes, read_pcoll.element_type) - # Ensure that the properties passed through correctly - source = read_pcoll.producer.transform.source - self.assertEqual('a_subscription', source.subscription) + source = pcoll.producer.transform._source + self.assertEqual('a_subscription', source.subscription_name) self.assertEqual('a_label', source.id_label) - def test_expand_with_both_topic_and_subscription(self): - with self.assertRaisesRegexp( - ValueError, "Only one of topic or subscription should be provided."): - ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label') - def test_expand_with_no_topic_or_subscription(self): with self.assertRaisesRegexp( ValueError, "Either a topic or subscription must be provided."): ReadStringsFromPubSub(None, None, 'a_label') + def test_expand_with_both_topic_and_subscription(self): + with self.assertRaisesRegexp( + ValueError, "Only one of topic or subscription should be provided."): + ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label') + +@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestWriteStringsToPubSub(unittest.TestCase): def test_expand(self): p = TestPipeline() - pdone = p | ReadStringsFromPubSub('baz') | WriteStringsToPubSub('a_topic') + pdone = (p + | ReadStringsFromPubSub('projects/fakeprj/topics/baz') + | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')) # Ensure that the properties passed through correctly - sink = pdone.producer.transform.sink - self.assertEqual('a_topic', sink.topic) - - # Ensure that the type on the intermediate payload transformer output - # PCollection is bytes - write_pcoll = pdone.producer.inputs[0] - self.assertEqual(bytes, write_pcoll.element_type) + self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name) +@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestPubSubSource(unittest.TestCase): - def test_display_data(self): - source = _PubSubPayloadSource('a_topic', 'a_subscription', 'a_label') + def test_display_data_topic(self): + 
source = _PubSubPayloadSource( + 'projects/fakeprj/topics/a_topic', + None, + 'a_label') + dd = DisplayData.create_from(source) + expected_items = [ + DisplayDataItemMatcher( + 'topic', 'projects/fakeprj/topics/a_topic'), + DisplayDataItemMatcher('id_label', 'a_label')] + + hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) + + def test_display_data_subscription(self): + source = _PubSubPayloadSource( + None, + 'projects/fakeprj/subscriptions/a_subscription', + 'a_label') dd = DisplayData.create_from(source) expected_items = [ - DisplayDataItemMatcher('topic', 'a_topic'), - DisplayDataItemMatcher('subscription', 'a_subscription'), + DisplayDataItemMatcher( + 'subscription', 'projects/fakeprj/subscriptions/a_subscription'), DisplayDataItemMatcher('id_label', 'a_label')] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) def test_display_data_no_subscription(self): - source = _PubSubPayloadSource('a_topic') + source = _PubSubPayloadSource('projects/fakeprj/topics/a_topic') dd = DisplayData.create_from(source) expected_items = [ - DisplayDataItemMatcher('topic', 'a_topic')] + DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) +@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') class TestPubSubSink(unittest.TestCase): def test_display_data(self): - sink = _PubSubPayloadSink('a_topic') + sink = _PubSubPayloadSink('projects/fakeprj/topics/a_topic') dd = DisplayData.create_from(sink) expected_items = [ - DisplayDataItemMatcher('topic', 'a_topic')] + DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index f213b3b9db33c..57bcc5e8cdda0 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -668,11 +668,12 @@ def run_Read(self, transform_node): raise ValueError('PubSubPayloadSource is currently available for use ' 'only in streaming pipelines.') # Only one of topic or subscription should be set. - if transform.source.topic: - step.add_property(PropertyNames.PUBSUB_TOPIC, transform.source.topic) - elif transform.source.subscription: + if transform.source.full_subscription: step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, - transform.source.subscription) + transform.source.full_subscription) + elif transform.source.full_topic: + step.add_property(PropertyNames.PUBSUB_TOPIC, + transform.source.full_topic) if transform.source.id_label: step.add_property(PropertyNames.PUBSUB_ID_LABEL, transform.source.id_label) @@ -756,7 +757,7 @@ def run__NativeWrite(self, transform_node): if not standard_options.streaming: raise ValueError('PubSubPayloadSink is currently available for use ' 'only in streaming pipelines.') - step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.topic) + step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.full_topic) else: raise ValueError( 'Sink %r has unexpected format %s.' 
% ( diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 2a75977576128..1a94b3d2b458d 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -26,8 +26,10 @@ import collections import logging +import apache_beam as beam from apache_beam import typehints from apache_beam.metrics.execution import MetricsEnvironment +from apache_beam.pvalue import PCollection from apache_beam.runners.direct.bundle_factory import BundleFactory from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner @@ -107,6 +109,58 @@ def apply__GroupAlsoByWindow(self, transform, pcoll): .with_output_types(*type_hints.output_types[0])) return transform.expand(pcoll) + def apply_ReadStringsFromPubSub(self, transform, pcoll): + try: + from google.cloud import pubsub as unused_pubsub + except ImportError: + raise ImportError('Google Cloud PubSub not available, please install ' + 'apache_beam[gcp]') + # Execute this as a native transform. + output = PCollection(pcoll.pipeline) + output.element_type = unicode + return output + + def apply_WriteStringsToPubSub(self, transform, pcoll): + try: + from google.cloud import pubsub + except ImportError: + raise ImportError('Google Cloud PubSub not available, please install ' + 'apache_beam[gcp]') + project = transform._sink.project + topic_name = transform._sink.topic_name + + class DirectWriteToPubSub(beam.DoFn): + _topic = None + + def __init__(self, project, topic_name): + self.project = project + self.topic_name = topic_name + + def start_bundle(self): + if self._topic is None: + self._topic = pubsub.Client(project=self.project).topic( + self.topic_name) + self._buffer = [] + + def process(self, elem): + self._buffer.append(elem.encode('utf-8')) + if len(self._buffer) >= 100: + self._flush() + + def finish_bundle(self): + self._flush() + + def _flush(self): + if self._buffer: + with self._topic.batch() as batch: + for datum in self._buffer: + batch.publish(datum) + self._buffer = [] + + output = pcoll | beam.ParDo(DirectWriteToPubSub(project, topic_name)) + output.element_type = unicode + return output + def run(self, pipeline): """Execute the entire pipeline and returns an DirectPipelineResult.""" diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 67b24927e2e38..641291d4857c2 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -20,6 +20,8 @@ from __future__ import absolute_import import collections +import random +import time from apache_beam import coders from apache_beam import pvalue @@ -48,6 +50,7 @@ from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn from apache_beam.utils import counters +from apache_beam.utils.timestamp import Timestamp from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.options.pipeline_options import TypeOptions @@ -63,6 +66,7 @@ def __init__(self, evaluation_context): self._evaluation_context = evaluation_context self._evaluators = { io.Read: _BoundedReadEvaluator, + io.ReadStringsFromPubSub: _PubSubReadEvaluator, core.Flatten: _FlattenEvaluator, core.ParDo: _ParDoEvaluator, core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator, @@ -357,6 +361,91 @@ def finish_bundle(self): {None: 
hold}) +class _PubSubSubscriptionWrapper(object): + """Wrapper for garbage-collecting temporary PubSub subscriptions.""" + + def __init__(self, subscription, should_cleanup): + self.subscription = subscription + self.should_cleanup = should_cleanup + + def __del__(self): + if self.should_cleanup: + self.subscription.delete() + + +class _PubSubReadEvaluator(_TransformEvaluator): + """TransformEvaluator for PubSub read.""" + + _subscription_cache = {} + + def __init__(self, evaluation_context, applied_ptransform, + input_committed_bundle, side_inputs, scoped_metrics_container): + assert not side_inputs + super(_PubSubReadEvaluator, self).__init__( + evaluation_context, applied_ptransform, input_committed_bundle, + side_inputs, scoped_metrics_container) + + source = self._applied_ptransform.transform._source + self._subscription = _PubSubReadEvaluator.get_subscription( + self._applied_ptransform, source.project, source.topic_name, + source.subscription_name) + + @classmethod + def get_subscription(cls, transform, project, topic, subscription_name): + if transform not in cls._subscription_cache: + from google.cloud import pubsub + should_create = not subscription_name + if should_create: + subscription_name = 'beam_%d_%x' % ( + int(time.time()), random.randrange(1 << 32)) + cls._subscription_cache[transform] = _PubSubSubscriptionWrapper( + pubsub.Client(project=project).topic(topic).subscription( + subscription_name), + should_create) + if should_create: + cls._subscription_cache[transform].subscription.create() + return cls._subscription_cache[transform].subscription + + def start_bundle(self): + pass + + def process_element(self, element): + pass + + def _read_from_pubsub(self): + from google.cloud import pubsub + # Because of the AutoAck, we are not able to reread messages if this + # evaluator fails with an exception before emitting a bundle. However, + # the DirectRunner currently doesn't retry work items anyway, so the + # pipeline would enter an inconsistent state on any error. + with pubsub.subscription.AutoAck( + self._subscription, return_immediately=True, + max_messages=10) as results: + return [message.data for unused_ack_id, message in results.items()] + + def finish_bundle(self): + data = self._read_from_pubsub() + if data: + output_pcollection = list(self._outputs)[0] + bundle = self._evaluation_context.create_bundle(output_pcollection) + # TODO(ccy): we currently do not use the PubSub message timestamp or + # respect the PubSub source's id_label field. 
+ now = Timestamp.of(time.time()) + for message_data in data: + bundle.output(GlobalWindows.windowed_value(message_data, timestamp=now)) + bundles = [bundle] + else: + bundles = [] + input_pvalue = self._applied_ptransform.inputs + if not input_pvalue: + input_pvalue = pvalue.PBegin(self._applied_ptransform.transform.pipeline) + unprocessed_bundle = self._evaluation_context.create_bundle( + input_pvalue) + return TransformResult( + self._applied_ptransform, bundles, + [unprocessed_bundle], None, {None: Timestamp.of(time.time())}) + + class _FlattenEvaluator(_TransformEvaluator): """TransformEvaluator for Flatten transform.""" From 39a2ed0ccb53bcc96c179c64405c80226bac7b9b Mon Sep 17 00:00:00 2001 From: Mairbek Khadikov Date: Thu, 29 Jun 2017 10:12:50 -0700 Subject: [PATCH 142/200] Ditch apache commons --- sdks/java/io/google-cloud-platform/pom.xml | 11 ----- .../beam/sdk/io/gcp/spanner/RandomUtils.java | 41 +++++++++++++++++++ .../sdk/io/gcp/spanner/SpannerReadIT.java | 11 ++--- .../sdk/io/gcp/spanner/SpannerWriteIT.java | 10 ++--- 4 files changed, 47 insertions(+), 26 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/RandomUtils.java diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml index 94066c7b038a5..09a430a0d146f 100644 --- a/sdks/java/io/google-cloud-platform/pom.xml +++ b/sdks/java/io/google-cloud-platform/pom.xml @@ -258,18 +258,7 @@ proto-google-common-protos - - org.apache.commons - commons-lang3 - provided - - - - org.apache.commons - commons-text - test - org.apache.beam beam-sdks-java-core diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/RandomUtils.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/RandomUtils.java new file mode 100644 index 0000000000000..f479b4a4bfafd --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/RandomUtils.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import java.util.Random; + +/** + * Useful randomness related utilities. 
+ */ +public class RandomUtils { + + private static final char[] ALPHANUMERIC = "1234567890abcdefghijklmnopqrstuvwxyz".toCharArray(); + + private RandomUtils() { + } + + public static String randomAlphaNumeric(int length) { + Random random = new Random(); + char[] result = new char[length]; + for (int i = 0; i < length; i++) { + result[i] = ALPHANUMERIC[random.nextInt(ALPHANUMERIC.length)]; + } + return new String(result); + } + +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java index ca43b40104e88..9f7c64eaadac4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java @@ -40,8 +40,6 @@ import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.commons.lang3.RandomStringUtils; -import org.apache.commons.text.RandomStringGenerator; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -127,7 +125,7 @@ public void testRead() throws Exception { .set("key") .to((long) i) .set("value") - .to(RandomStringUtils.random(100, true, true)) + .to(RandomUtils.randomAlphaNumeric(100)) .build()); } @@ -161,11 +159,8 @@ public void tearDown() throws Exception { } private String generateDatabaseName() { - String random = - new RandomStringGenerator.Builder() - .withinRange('a', 'z') - .build() - .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); + String random = RandomUtils + .randomAlphaNumeric(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); return options.getDatabaseIdPrefix() + "-" + random; } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index 613756cff2481..2f6cd55b3f56f 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -41,7 +41,6 @@ import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.commons.text.RandomStringGenerator; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -114,11 +113,8 @@ public void setUp() throws Exception { } private String generateDatabaseName() { - String random = - new RandomStringGenerator.Builder() - .withinRange('a', 'z') - .build() - .generate(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); + String random = RandomUtils + .randomAlphaNumeric(MAX_DB_NAME_LENGTH - 1 - options.getDatabaseIdPrefix().length()); return options.getDatabaseIdPrefix() + "-" + random; } @@ -166,7 +162,7 @@ public void processElement(ProcessContext c) { Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table); Long key = c.element(); builder.set("Key").to(key); - builder.set("Value").to(new RandomStringGenerator.Builder().build().generate(valueSize)); + builder.set("Value").to(RandomUtils.randomAlphaNumeric(valueSize)); Mutation mutation = builder.build(); c.output(mutation); } From 
f46a40c279499737bb7fb45af5e299d76f6af49b Mon Sep 17 00:00:00 2001 From: Valentyn Tymofieiev Date: Wed, 28 Jun 2017 16:41:03 -0700 Subject: [PATCH 143/200] Use SDK harness container for FnAPI jobs when worker_harness_container_image is not specified. Add a separate image tag to use with the SDK harness container. --- .../runners/dataflow/internal/apiclient.py | 6 +-- .../runners/dataflow/internal/dependency.py | 44 ++++++++++++++++--- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index df1a3f22d9cd7..edac9d7d55858 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -38,7 +38,6 @@ from apache_beam.io.gcp.internal.clients import storage from apache_beam.runners.dataflow.internal import dependency from apache_beam.runners.dataflow.internal.clients import dataflow -from apache_beam.runners.dataflow.internal.dependency import get_required_container_version from apache_beam.runners.dataflow.internal.dependency import get_sdk_name_and_version from apache_beam.runners.dataflow.internal.names import PropertyNames from apache_beam.transforms import cy_combiners @@ -205,11 +204,8 @@ def __init__(self, packages, options, environment_version): pool.workerHarnessContainerImage = ( self.worker_options.worker_harness_container_image) else: - # Default to using the worker harness container image for the current SDK - # version. pool.workerHarnessContainerImage = ( - 'dataflow.gcr.io/v1beta3/python:%s' % - get_required_container_version()) + dependency.get_default_container_image_for_current_sdk(job_type)) if self.worker_options.use_public_ips is not None: if self.worker_options.use_public_ips: pool.ipConfiguration = ( diff --git a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py index 03e17940e5943..a40a58273d24c 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py @@ -71,9 +71,15 @@ # Update this version to the next version whenever there is a change that will -# require changes to the execution environment. +# require changes to legacy Dataflow worker execution environment. # This should be in the beam-[version]-[date] format, date is optional. +# BEAM_CONTAINER_VERSION and BEAM_FNAPI_CONTAINER version should coincide +# when we make a release. BEAM_CONTAINER_VERSION = 'beam-2.1.0-20170626' +# Update this version to the next version whenever there is a change that +# requires changes to SDK harness container or SDK harness launcher. +# This should be in the beam-[version]-[date] format, date is optional. +BEAM_FNAPI_CONTAINER_VERSION = 'beam-2.1.0-20170621' # Standard file names used for staging files. WORKFLOW_TARBALL_FILE = 'workflow.tar.gz' @@ -474,10 +480,33 @@ def _stage_beam_sdk_tarball(sdk_remote_location, staged_path, temp_dir): 'type of location: %s' % sdk_remote_location) -def get_required_container_version(): +def get_default_container_image_for_current_sdk(job_type): + """For internal use only; no backwards-compatibility guarantees. + + Args: + job_type (str): BEAM job type. + + Returns: + str: Google Cloud Dataflow container image for remote execution. + """ + # TODO(tvalentyn): Use enumerated type instead of strings for job types. 
+ if job_type == 'FNAPI_BATCH' or job_type == 'FNAPI_STREAMING': + image_name = 'dataflow.gcr.io/v1beta3/python-fnapi' + else: + image_name = 'dataflow.gcr.io/v1beta3/python' + image_tag = _get_required_container_version(job_type) + return image_name + ':' + image_tag + + +def _get_required_container_version(job_type=None): """For internal use only; no backwards-compatibility guarantees. - Returns the Google Cloud Dataflow container version for remote execution. + Args: + job_type (str, optional): BEAM job type. Defaults to None. + + Returns: + str: The tag of worker container images in GCR that corresponds to + current version of the SDK. """ # TODO(silviuc): Handle apache-beam versions when we have official releases. import pkg_resources as pkg @@ -493,7 +522,10 @@ def get_required_container_version(): except pkg.DistributionNotFound: # This case covers Apache Beam end-to-end testing scenarios. All these tests # will run with a special container version. - return BEAM_CONTAINER_VERSION + if job_type == 'FNAPI_BATCH' or job_type == 'FNAPI_STREAMING': + return BEAM_FNAPI_CONTAINER_VERSION + else: + return BEAM_CONTAINER_VERSION def get_sdk_name_and_version(): @@ -501,7 +533,7 @@ def get_sdk_name_and_version(): Returns name and version of SDK reported to Google Cloud Dataflow.""" import pkg_resources as pkg - container_version = get_required_container_version() + container_version = _get_required_container_version() try: pkg.get_distribution(GOOGLE_PACKAGE_NAME) return ('Google Cloud Dataflow SDK for Python', container_version) @@ -513,7 +545,7 @@ def get_sdk_package_name(): """For internal use only; no backwards-compatibility guarantees. Returns the PyPI package name to be staged to Google Cloud Dataflow.""" - container_version = get_required_container_version() + container_version = _get_required_container_version() if container_version == BEAM_CONTAINER_VERSION: return BEAM_PACKAGE_NAME return GOOGLE_PACKAGE_NAME From c6de4233d1c1bc812ba2dea45291d9dcb40aa152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Thu, 29 Jun 2017 13:25:36 +0200 Subject: [PATCH 144/200] Define the projectId in the SpannerIO Read Test (utest, not itest) --- .../apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java index e5d4e72f51981..5ba2da09e77c3 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java @@ -150,6 +150,7 @@ public void validQuery() throws Exception { public void runQuery() throws Exception { SpannerIO.Read read = SpannerIO.read() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withTimestamp(Timestamp.now()) @@ -176,6 +177,7 @@ public void runQuery() throws Exception { public void runRead() throws Exception { SpannerIO.Read read = SpannerIO.read() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withTimestamp(Timestamp.now()) @@ -202,6 +204,7 @@ public void runRead() throws Exception { public void runReadUsingIndex() throws Exception { SpannerIO.Read read = SpannerIO.read() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withTimestamp(Timestamp.now()) @@ 
-232,11 +235,13 @@ public void readPipeline() throws Exception { PCollectionView tx = pipeline .apply("tx", SpannerIO.createTransaction() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withServiceFactory(serviceFactory)); PCollection one = pipeline.apply("read q", SpannerIO.read() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withTimestamp(Timestamp.now()) @@ -244,6 +249,7 @@ public void readPipeline() throws Exception { .withServiceFactory(serviceFactory) .withTransaction(tx)); PCollection two = pipeline.apply("read r", SpannerIO.read() + .withProjectId("test") .withInstanceId("123") .withDatabaseId("aaa") .withTimestamp(Timestamp.now()) From 5744fa84520c16cce73752e7d04e8b6628ef8979 Mon Sep 17 00:00:00 2001 From: Michael Luckey Date: Mon, 29 May 2017 01:00:48 +0200 Subject: [PATCH 145/200] [BEAM-2373] Upgrade commons-compress dependency version to 1.14 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 536a11c5bbfdd..fe51660fdc8e8 100644 --- a/pom.xml +++ b/pom.xml @@ -101,7 +101,7 @@ - 1.9 + 1.14 3.6 1.1 2.24.0 From dc1dca8633775545b5b4b509724716108d5d01e4 Mon Sep 17 00:00:00 2001 From: Ahmet Altay Date: Thu, 29 Jun 2017 10:56:25 -0700 Subject: [PATCH 146/200] Select SDK distribution based on the selected SDK name --- .../runners/dataflow/internal/dependency.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py index a40a58273d24c..62c09ed141352 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/dependency.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/dependency.py @@ -69,12 +69,15 @@ from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.options.pipeline_options import SetupOptions +# All constants are for internal use only; no backwards-compatibility +# guarantees. +# In a released version BEAM_CONTAINER_VERSION and BEAM_FNAPI_CONTAINER_VERSION +# should match each other, and should be in the same format as the SDK version +# (i.e. MAJOR.MINOR.PATCH). For non-released (dev) versions, read below. # Update this version to the next version whenever there is a change that will # require changes to legacy Dataflow worker execution environment. # This should be in the beam-[version]-[date] format, date is optional. -# BEAM_CONTAINER_VERSION and BEAM_FNAPI_CONTAINER version should coincide -# when we make a release. BEAM_CONTAINER_VERSION = 'beam-2.1.0-20170626' # Update this version to the next version whenever there is a change that # requires changes to SDK harness container or SDK harness launcher. 
@@ -86,9 +89,14 @@ REQUIREMENTS_FILE = 'requirements.txt' EXTRA_PACKAGES_FILE = 'extra_packages.txt' +# Package names for different distributions GOOGLE_PACKAGE_NAME = 'google-cloud-dataflow' BEAM_PACKAGE_NAME = 'apache-beam' +# SDK identifiers for different distributions +GOOGLE_SDK_NAME = 'Google Cloud Dataflow SDK for Python' +BEAM_SDK_NAME = 'Apache Beam SDK for Python' + def _dependency_file_copy(from_path, to_path): """Copies a local file to a GCS file or vice versa.""" @@ -536,19 +544,20 @@ def get_sdk_name_and_version(): container_version = _get_required_container_version() try: pkg.get_distribution(GOOGLE_PACKAGE_NAME) - return ('Google Cloud Dataflow SDK for Python', container_version) + return (GOOGLE_SDK_NAME, container_version) except pkg.DistributionNotFound: - return ('Apache Beam SDK for Python', beam_version.__version__) + return (BEAM_SDK_NAME, beam_version.__version__) def get_sdk_package_name(): """For internal use only; no backwards-compatibility guarantees. Returns the PyPI package name to be staged to Google Cloud Dataflow.""" - container_version = _get_required_container_version() - if container_version == BEAM_CONTAINER_VERSION: + sdk_name, _ = get_sdk_name_and_version() + if sdk_name == GOOGLE_SDK_NAME: + return GOOGLE_PACKAGE_NAME + else: return BEAM_PACKAGE_NAME - return GOOGLE_PACKAGE_NAME def _download_pypi_sdk_package(temp_dir): From 0acfe70439b50184a69601ca4bb8cff9780fa4ef Mon Sep 17 00:00:00 2001 From: Stephen Sisk Date: Wed, 28 Jun 2017 15:34:45 -0700 Subject: [PATCH 147/200] GCP IO ITs now all use --project option Up until now, some IO ITs used --projectId and others used --project This mixing meant that running all the tests in one test run was impossible. --- .../beam/sdk/io/gcp/bigtable/BigtableReadIT.java | 5 ++++- .../sdk/io/gcp/bigtable/BigtableTestOptions.java | 5 ----- .../sdk/io/gcp/bigtable/BigtableWriteIT.java | 4 +++- .../beam/sdk/io/gcp/spanner/SpannerReadIT.java | 16 ++++++++-------- .../beam/sdk/io/gcp/spanner/SpannerWriteIT.java | 15 +++++++-------- 5 files changed, 22 insertions(+), 23 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java index a064bd64235ba..e47fd0f23c4cc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java @@ -20,6 +20,7 @@ import com.google.bigtable.v2.Row; import com.google.cloud.bigtable.config.BigtableOptions; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -41,8 +42,10 @@ public void testE2EBigtableRead() throws Exception { BigtableTestOptions options = TestPipeline.testingPipelineOptions() .as(BigtableTestOptions.class); + String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + BigtableOptions.Builder bigtableOptionsBuilder = new BigtableOptions.Builder() - .setProjectId(options.getProjectId()) + .setProjectId(project) .setInstanceId(options.getInstanceId()); final String tableId = "BigtableReadTest"; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableTestOptions.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableTestOptions.java index 0ab757621109e..03cb6979e2d75 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableTestOptions.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableTestOptions.java @@ -25,11 +25,6 @@ * Properties needed when using Bigtable with the Beam SDK. */ public interface BigtableTestOptions extends TestPipelineOptions { - @Description("Project ID for Bigtable") - @Default.String("apache-beam-testing") - String getProjectId(); - void setProjectId(String value); - @Description("Instance ID for Bigtable") @Default.String("beam-test") String getInstanceId(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java index 1d168f169985d..72ba8363a4dbe 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java @@ -73,15 +73,17 @@ public class BigtableWriteIT implements Serializable { private static BigtableTableAdminClient tableAdminClient; private final String tableId = String.format("BigtableWriteIT-%tF-% mutations = new ArrayList<>(); for (int i = 0; i < 5L; i++) { @@ -134,7 +134,7 @@ public void testRead() throws Exception { databaseClient.writeAtLeastOnce(mutations); SpannerConfig spannerConfig = SpannerConfig.create() - .withProjectId(options.getProjectId()) + .withProjectId(project) .withInstanceId(options.getInstanceId()) .withDatabaseId(databaseName); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index 613756cff2481..78a360f36679d 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -33,6 +33,7 @@ import com.google.cloud.spanner.Statement; import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; import java.util.Collections; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -59,11 +60,6 @@ public class SpannerWriteIT { /** Pipeline options for this test. 
*/ public interface SpannerTestPipelineOptions extends TestPipelineOptions { - @Description("Project ID for Spanner") - @Default.String("apache-beam-testing") - String getProjectId(); - void setProjectId(String value); - @Description("Instance ID to write to in Spanner") @Default.String("beam-test") String getInstanceId(); @@ -84,13 +80,16 @@ public interface SpannerTestPipelineOptions extends TestPipelineOptions { private DatabaseAdminClient databaseAdminClient; private SpannerTestPipelineOptions options; private String databaseName; + private String project; @Before public void setUp() throws Exception { PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); - spanner = SpannerOptions.newBuilder().setProjectId(options.getProjectId()).build().getService(); + project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + + spanner = SpannerOptions.newBuilder().setProjectId(project).build().getService(); databaseName = generateDatabaseName(); @@ -128,7 +127,7 @@ public void testWrite() throws Exception { .apply(ParDo.of(new GenerateMutations(options.getTable()))) .apply( SpannerIO.write() - .withProjectId(options.getProjectId()) + .withProjectId(project) .withInstanceId(options.getInstanceId()) .withDatabaseId(databaseName)); @@ -136,7 +135,7 @@ public void testWrite() throws Exception { DatabaseClient databaseClient = spanner.getDatabaseClient( DatabaseId.of( - options.getProjectId(), options.getInstanceId(), databaseName)); + project, options.getInstanceId(), databaseName)); ResultSet resultSet = databaseClient From ab7f6f6dbd3ad67c5da577bc9395cb09e35069d9 Mon Sep 17 00:00:00 2001 From: Stephen Sisk Date: Thu, 29 Jun 2017 13:29:22 -0700 Subject: [PATCH 148/200] Don't call .testingPipelineOptions() a second time --- .../org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java | 2 +- .../org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java | 2 +- .../java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java | 2 +- .../java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java index e47fd0f23c4cc..91f0baefac634 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableReadIT.java @@ -42,7 +42,7 @@ public void testE2EBigtableRead() throws Exception { BigtableTestOptions options = TestPipeline.testingPipelineOptions() .as(BigtableTestOptions.class); - String project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + String project = options.as(GcpOptions.class).getProject(); BigtableOptions.Builder bigtableOptionsBuilder = new BigtableOptions.Builder() .setProjectId(project) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java index 72ba8363a4dbe..010bcc40590f1 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java +++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteIT.java @@ -79,7 +79,7 @@ public class BigtableWriteIT implements Serializable { public void setup() throws Exception { PipelineOptionsFactory.register(BigtableTestOptions.class); options = TestPipeline.testingPipelineOptions().as(BigtableTestOptions.class); - project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + project = options.as(GcpOptions.class).getProject(); bigtableOptions = new Builder() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java index 32183f91788c3..bfbda503abc03 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadIT.java @@ -88,7 +88,7 @@ public void setUp() throws Exception { PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); - project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + project = options.as(GcpOptions.class).getProject(); spanner = SpannerOptions.newBuilder().setProjectId(project).build().getService(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java index 78a360f36679d..436d01eff70fd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -87,7 +87,7 @@ public void setUp() throws Exception { PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); - project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + project = options.as(GcpOptions.class).getProject(); spanner = SpannerOptions.newBuilder().setProjectId(project).build().getService(); From 7b8cd6401cb5ed6e184ed36571a89d3ae324dd5f Mon Sep 17 00:00:00 2001 From: Jeremie Lenfant-Engelmann Date: Wed, 28 Jun 2017 18:32:56 -0700 Subject: [PATCH 149/200] Properly convert milliseconds whether there's less than 3/more than 9 digits. TimeUtil did not properly convert (and returned null) when the number of digits for fractions of seconds was less than 3 digits or more than 9 digits. The solution is to pad with zeros when there is less than 3 digits and to truncate when there is more than 3. 
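For readers skimming the patch, the rule above can be shown in isolation. The following is a minimal, self-contained sketch that mirrors the computeMillis helper added in the diff below; the wrapper class and main harness are illustrative only, and the expected values match the new test cases (for example a fraction of "42" parses to 420 ms):

    import com.google.common.base.Strings;

    public class FractionToMillis {
      /** Millisecond value of the fractional-seconds digits of a timestamp string. */
      static int computeMillis(String frac) {
        if (frac == null) {
          return 0; // no fractional part at all
        }
        // Millisecond resolution: truncate anything past the third digit,
        // and pad shorter fractions with trailing zeros ("42" -> "420").
        return Integer.valueOf(
            frac.length() > 3 ? frac.substring(0, 3) : Strings.padEnd(frac, 3, '0'));
      }

      public static void main(String[] args) {
        System.out.println(computeMillis("42"));        // 420 (padded)
        System.out.println(computeMillis("3"));         // 300 (padded)
        System.out.println(computeMillis("001000001")); // 1   (truncated to "001")
      }
    }
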
--- .../beam/runners/dataflow/util/TimeUtil.java | 24 +++++++------------ .../runners/dataflow/util/TimeUtilTest.java | 6 +++++ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/TimeUtil.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/TimeUtil.java index bff379fc1cc70..172dc6ee03dbb 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/TimeUtil.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/TimeUtil.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.util; +import com.google.common.base.Strings; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.annotation.Nullable; @@ -98,26 +99,19 @@ public static Instant fromCloudTime(String time) { int hour = Integer.valueOf(matcher.group(4)); int minute = Integer.valueOf(matcher.group(5)); int second = Integer.valueOf(matcher.group(6)); - int millis = 0; - - String frac = matcher.group(7); - if (frac != null) { - int fracs = Integer.valueOf(frac); - if (frac.length() == 3) { // millisecond resolution - millis = fracs; - } else if (frac.length() == 6) { // microsecond resolution - millis = fracs / 1000; - } else if (frac.length() == 9) { // nanosecond resolution - millis = fracs / 1000000; - } else { - return null; - } - } + int millis = computeMillis(matcher.group(7)); return new DateTime(year, month, day, hour, minute, second, millis, ISOChronology.getInstanceUTC()).toInstant(); } + private static int computeMillis(String frac) { + if (frac == null) { + return 0; + } + return Integer.valueOf(frac.length() > 3 ? frac.substring(0, 3) : Strings.padEnd(frac, 3, '0')); + } + /** * Converts a {@link ReadableDuration} into a Dataflow API duration string. */ diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/TimeUtilTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/TimeUtilTest.java index e0785d424fe70..1ac9fabf6a455 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/TimeUtilTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/TimeUtilTest.java @@ -47,8 +47,14 @@ public void fromCloudTimeShouldParseTimeStrings() { assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001001Z")); assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000000Z")); assertEquals(new Instant(1), fromCloudTime("1970-01-01T00:00:00.001000001Z")); + assertEquals(new Instant(0), fromCloudTime("1970-01-01T00:00:00.0Z")); + assertEquals(new Instant(0), fromCloudTime("1970-01-01T00:00:00.00Z")); + assertEquals(new Instant(420), fromCloudTime("1970-01-01T00:00:00.42Z")); + assertEquals(new Instant(300), fromCloudTime("1970-01-01T00:00:00.3Z")); + assertEquals(new Instant(20), fromCloudTime("1970-01-01T00:00:00.02Z")); assertNull(fromCloudTime("")); assertNull(fromCloudTime("1970-01-01T00:00:00")); + assertNull(fromCloudTime("1970-01-01T00:00:00.1e3Z")); } @Test From 56cb6c51748fde6ad56522733ab10edca062e802 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 13 Jun 2017 10:29:50 -0700 Subject: [PATCH 150/200] Add support for PipelineOptions parameters This is a step towards eliminating catch-all context parameters and making DoFns express their fine-grained data needs. 
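As a usage illustration, the new parameter support lets a DoFn declare PipelineOptions directly in its @ProcessElement signature and have the runner inject the pipeline's options at execution time. The sketch below is adapted from the ParDoTest case added in this patch; the MyOptions interface matches the test, while the named class EmitFakeOptionFn is our own stand-in for the anonymous DoFn used there:

    import org.apache.beam.sdk.options.Default;
    import org.apache.beam.sdk.options.PipelineOptions;
    import org.apache.beam.sdk.transforms.DoFn;

    /** A custom options interface, mirroring MyOptions from the added test. */
    interface MyOptions extends PipelineOptions {
      @Default.String("fake option")
      String getFakeOption();
      void setFakeOption(String value);
    }

    /** Requests PipelineOptions by declaring it as an extra @ProcessElement parameter. */
    class EmitFakeOptionFn extends DoFn<Integer, String> {
      @ProcessElement
      public void process(ProcessContext c, PipelineOptions options) {
        // The runner supplies the options of the executing pipeline here.
        c.output(options.as(MyOptions.class).getFakeOption());
      }
    }
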
--- ...oundedSplittableProcessElementInvoker.java | 5 ++ .../beam/runners/core/SimpleDoFnRunner.java | 20 ++++++ .../beam/sdk/transforms/DoFnTester.java | 5 ++ .../reflect/ByteBuddyDoFnInvokerFactory.java | 6 ++ .../sdk/transforms/reflect/DoFnInvoker.java | 13 +++- .../sdk/transforms/reflect/DoFnSignature.java | 23 +++++++ .../transforms/reflect/DoFnSignatures.java | 22 ++++++- .../apache/beam/sdk/transforms/ParDoTest.java | 63 +++++++++++++++++++ .../reflect/DoFnSignaturesTest.java | 14 +++++ 9 files changed, 169 insertions(+), 2 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 2db6531e050ca..475abf25eaa06 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -117,6 +117,11 @@ public BoundedWindow window() { "Access to window of the element not supported in Splittable DoFn"); } + @Override + public PipelineOptions pipelineOptions() { + return pipelineOptions; + } + @Override public StartBundleContext startBundleContext(DoFn doFn) { throw new IllegalStateException( diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index 7d7babd1397c6..c3bfef6fec832 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -232,6 +232,11 @@ public BoundedWindow window() { "Cannot access window outside of @ProcessElement and @OnTimer methods."); } + @Override + public PipelineOptions pipelineOptions() { + return getPipelineOptions(); + } + @Override public DoFn.StartBundleContext startBundleContext(DoFn doFn) { return this; @@ -297,6 +302,11 @@ public BoundedWindow window() { "Cannot access window outside of @ProcessElement and @OnTimer methods."); } + @Override + public PipelineOptions pipelineOptions() { + return getPipelineOptions(); + } + @Override public DoFn.StartBundleContext startBundleContext(DoFn doFn) { throw new UnsupportedOperationException( @@ -466,6 +476,11 @@ public BoundedWindow window() { return Iterables.getOnlyElement(elem.getWindows()); } + @Override + public PipelineOptions pipelineOptions() { + return getPipelineOptions(); + } + @Override public DoFn.StartBundleContext startBundleContext(DoFn doFn) { throw new UnsupportedOperationException("StartBundleContext parameters are not supported."); @@ -567,6 +582,11 @@ public BoundedWindow window() { return window; } + @Override + public PipelineOptions pipelineOptions() { + return getPipelineOptions(); + } + @Override public DoFn.StartBundleContext startBundleContext(DoFn doFn) { throw new UnsupportedOperationException("StartBundleContext parameters are not supported."); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 4da9a8096f44f..b2377dd2befd8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -289,6 +289,11 @@ public 
BoundedWindow window() { return window; } + @Override + public PipelineOptions pipelineOptions() { + return getPipelineOptions(); + } + @Override public DoFn.StartBundleContext startBundleContext( DoFn doFn) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 4f67db4b33ce3..837820411d9bd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -90,6 +90,7 @@ public class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory { public static final String PROCESS_CONTEXT_PARAMETER_METHOD = "processContext"; public static final String ON_TIMER_CONTEXT_PARAMETER_METHOD = "onTimerContext"; public static final String WINDOW_PARAMETER_METHOD = "window"; + public static final String PIPELINE_OPTIONS_PARAMETER_METHOD = "pipelineOptions"; public static final String RESTRICTION_TRACKER_PARAMETER_METHOD = "restrictionTracker"; public static final String STATE_PARAMETER_METHOD = "state"; public static final String TIMER_PARAMETER_METHOD = "timer"; @@ -627,6 +628,11 @@ public StackManipulation dispatch(TimerParameter p) { getExtraContextFactoryMethodDescription(TIMER_PARAMETER_METHOD, String.class)), TypeCasting.to(new TypeDescription.ForLoadedType(Timer.class))); } + + @Override + public StackManipulation dispatch(DoFnSignature.Parameter.PipelineOptionsParameter p) { + return simpleExtraContextParameter(PIPELINE_OPTIONS_PARAMETER_METHOD); + } }); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index ed81f42870bbe..3b22fdaccb01c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -19,6 +19,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.transforms.DoFn; @@ -102,7 +103,12 @@ interface ArgumentProvider { */ BoundedWindow window(); - /** Provide a {@link DoFn.StartBundleContext} to use with the given {@link DoFn}. */ + /** Provide {@link PipelineOptions}. */ + PipelineOptions pipelineOptions(); + + /** + * Provide a {@link DoFn.StartBundleContext} to use with the given {@link DoFn}. + */ DoFn.StartBundleContext startBundleContext(DoFn doFn); /** Provide a {@link DoFn.FinishBundleContext} to use with the given {@link DoFn}. 
*/ @@ -139,6 +145,11 @@ public BoundedWindow window() { return null; } + @Override + public PipelineOptions pipelineOptions() { + return null; + } + @Override public DoFn.StartBundleContext startBundleContext(DoFn doFn) { return null; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 0b4bf90071a8d..6eeed8e054950 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -27,6 +27,7 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.Timer; @@ -193,6 +194,8 @@ public ResultT match(Cases cases) { return cases.dispatch((StateParameter) this); } else if (this instanceof TimerParameter) { return cases.dispatch((TimerParameter) this); + } else if (this instanceof PipelineOptionsParameter) { + return cases.dispatch((PipelineOptionsParameter) this); } else { throw new IllegalStateException( String.format("Attempt to case match on unknown %s subclass %s", @@ -212,6 +215,7 @@ public interface Cases { ResultT dispatch(RestrictionTrackerParameter p); ResultT dispatch(StateParameter p); ResultT dispatch(TimerParameter p); + ResultT dispatch(PipelineOptionsParameter p); /** * A base class for a visitor with a default method for cases it is not interested in. @@ -259,6 +263,11 @@ public ResultT dispatch(StateParameter p) { public ResultT dispatch(TimerParameter p) { return dispatchDefault(p); } + + @Override + public ResultT dispatch(PipelineOptionsParameter p) { + return dispatchDefault(p); + } } } @@ -287,6 +296,11 @@ public static WindowParameter boundedWindow(TypeDescriptor> ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = ImmutableList.of( - Parameter.ProcessContextParameter.class, Parameter.RestrictionTrackerParameter.class); + Parameter.PipelineOptionsParameter.class, + Parameter.ProcessContextParameter.class, + Parameter.RestrictionTrackerParameter.class); private static final Collection> ALLOWED_ON_TIMER_PARAMETERS = ImmutableList.of( Parameter.OnTimerContextParameter.class, Parameter.WindowParameter.class, + Parameter.PipelineOptionsParameter.class, Parameter.TimerParameter.class, Parameter.StateParameter.class); @@ -187,6 +192,15 @@ public boolean hasWindowParameter() { extraParameters, Predicates.instanceOf(WindowParameter.class)); } + /** + * Indicates whether a {@link Parameter.PipelineOptionsParameter} is + * known in this context. + */ + public boolean hasPipelineOptionsParamter() { + return Iterables.any( + extraParameters, Predicates.instanceOf(Parameter.PipelineOptionsParameter.class)); + } + /** The window type, if any, used by this method. 
*/ @Nullable public TypeDescriptor getWindowType() { @@ -789,6 +803,12 @@ private static Parameter analyzeExtraParameter( "Multiple %s parameters", BoundedWindow.class.getSimpleName()); return Parameter.boundedWindow((TypeDescriptor) paramT); + } else if (PipelineOptions.class.equals(rawType)) { + methodErrors.checkArgument( + !methodContext.hasPipelineOptionsParamter(), + "Multiple %s parameters", + PipelineOptions.class.getSimpleName()); + return Parameter.pipelineOptions(); } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { methodErrors.checkArgument( !methodContext.hasRestrictionTrackerParameter(), diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index c67cf2a758fe7..5b60ef3ed03b1 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -62,6 +62,8 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; @@ -2942,4 +2944,65 @@ public void onTimer(BoundedWindow w) {} // If it doesn't crash, we made it! } + + /** A {@link PipelineOptions} subclass for testing passing to a {@link DoFn}. */ + public interface MyOptions extends PipelineOptions { + @Default.String("fake option") + String getFakeOption(); + void setFakeOption(String value); + } + + @Test + @Category(ValidatesRunner.class) + public void testPipelineOptionsParameter() { + PCollection results = pipeline + .apply(Create.of(1)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void process(ProcessContext c, PipelineOptions options) { + c.output(options.as(MyOptions.class).getFakeOption()); + } + })); + + String testOptionValue = "not fake anymore"; + pipeline.getOptions().as(MyOptions.class).setFakeOption(testOptionValue); + PAssert.that(results).containsInAnyOrder("not fake anymore"); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesTimersInParDo.class}) + public void testPipelineOptionsParameterOnTimer() { + final String timerId = "thisTimer"; + + PCollection results = + pipeline + .apply(Create.of(KV.of(0, 0))) + .apply( + ParDo.of( + new DoFn, String>() { + @TimerId(timerId) + private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void process( + ProcessContext c, BoundedWindow w, @TimerId(timerId) Timer timer) { + timer.set(w.maxTimestamp()); + } + + @OnTimer(timerId) + public void onTimer(OnTimerContext c, PipelineOptions options) { + c.output(options.as(MyOptions.class).getFakeOption()); + } + })); + + String testOptionValue = "not fake anymore"; + pipeline.getOptions().as(MyOptions.class).setFakeOption(testOptionValue); + PAssert.that(results).containsInAnyOrder("not fake anymore"); + + pipeline.run(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index cffb0adf613a0..70c8dfdb312f4 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.state.TimeDomain; @@ -328,6 +329,19 @@ public void onTimer(BoundedWindow w) {} instanceOf(WindowParameter.class)); } + @Test + public void testPipelineOptionsParameter() throws Exception { + DoFnSignature sig = + DoFnSignatures.getSignature(new DoFn() { + @ProcessElement + public void process(ProcessContext c, PipelineOptions options) {} + }.getClass()); + + assertThat( + sig.processElement().extraParameters(), + Matchers.hasItem(instanceOf(Parameter.PipelineOptionsParameter.class))); + } + @Test public void testDeclAndUsageOfTimerInSuperclass() throws Exception { DoFnSignature sig = From f99ab1a472868e4ce175a86e5d76823b1c09c10b Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 30 Jun 2017 21:42:17 -0700 Subject: [PATCH 151/200] Fix DoFn javadoc: StateSpec does not require a key --- .../core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index fb6d0ee4ffe58..a2e5c162c7cb5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -385,7 +385,7 @@ public interface OutputReceiver { *

    {@literal new DoFn, Baz>()} {
        *
        *  {@literal @StateId("my-state-id")}
    -   *  {@literal private final StateSpec<K, ValueState<MyState>>} myStateSpec =
    +   *  {@literal private final StateSpec<ValueState<MyState>>} myStateSpec =
        *       StateSpecs.value(new MyStateCoder());
        *
        *  {@literal @ProcessElement}
    
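To make the corrected javadoc concrete, here is a hedged sketch of a stateful DoFn written the way the fixed example describes: the StateSpec is parameterized only by the kind of state cell, and the key is implied by the KV input element. The class name CountPerKeyFn and the counting logic are illustrative, not taken from the patch:

    import org.apache.beam.sdk.coders.VarIntCoder;
    import org.apache.beam.sdk.state.StateSpec;
    import org.apache.beam.sdk.state.StateSpecs;
    import org.apache.beam.sdk.state.ValueState;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;

    /** Counts elements per key in a ValueState cell; no key type appears in the spec. */
    class CountPerKeyFn extends DoFn<KV<String, Integer>, Integer> {

      @StateId("count")
      private final StateSpec<ValueState<Integer>> countSpec =
          StateSpecs.value(VarIntCoder.of());

      @ProcessElement
      public void processElement(
          ProcessContext c, @StateId("count") ValueState<Integer> count) {
        Integer previous = count.read();        // null on the first element for this key
        int current = (previous == null ? 0 : previous) + 1;
        count.write(current);
        c.output(current);
      }
    }
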
    From ce4e51747501111ae2c4b1691c6994bd0f92e161 Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= 
    Date: Sun, 4 Jun 2017 22:55:05 +0200
    Subject: [PATCH 152/200] Make modules that depend on Hadoop and Spark use the
     same version property
    
    ---
     examples/java/pom.xml                         | 18 +++--
     examples/java8/pom.xml                        | 18 +++--
     pom.xml                                       | 65 ++++++++++++++++++-
     runners/apex/pom.xml                          |  2 +-
     runners/spark/pom.xml                         |  7 --
     sdks/java/extensions/sorter/pom.xml           |  6 --
     sdks/java/io/hadoop-file-system/pom.xml       | 31 ---------
     sdks/java/io/hadoop/jdk1.8-tests/pom.xml      |  2 -
     sdks/java/io/hbase/pom.xml                    |  9 ++-
     sdks/java/io/hcatalog/pom.xml                 |  6 +-
     sdks/java/io/jdbc/pom.xml                     |  2 -
     sdks/java/io/pom.xml                          | 31 ---------
     sdks/java/javadoc/pom.xml                     |  2 -
     .../resources/archetype-resources/pom.xml     |  1 -
     .../resources/archetype-resources/pom.xml     |  1 -
     15 files changed, 98 insertions(+), 103 deletions(-)
    
    diff --git a/examples/java/pom.xml b/examples/java/pom.xml
    index 701e4fe76cce7..7ae4e6ad36e66 100644
    --- a/examples/java/pom.xml
    +++ b/examples/java/pom.xml
    @@ -34,10 +34,6 @@
     
       jar
     
    -  
    -    1.6.2
    -  
    -
       
     
         
    +    2.7.3
         1.3
         2.8.9
         3.0.1
    @@ -139,7 +145,7 @@
         v1-rev10-1.22.0
         1.7.14
         0.20.0-beta
    -    1.6.2
    +    1.6.3
         4.3.5.RELEASE
         3.1.4
         v1-rev71-1.22.0
    @@ -1075,6 +1081,42 @@
             ${snappy-java.version}
           
     
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-client</artifactId>
+        <version>${hadoop.version}</version>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-common</artifactId>
+        <version>${hadoop.version}</version>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-mapreduce-client-core</artifactId>
+        <version>${hadoop.version}</version>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-core_2.10</artifactId>
+        <version>${spark.version}</version>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-streaming_2.10</artifactId>
+        <version>${spark.version}</version>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-network-common_2.10</artifactId>
+        <version>${spark.version}</version>
+      </dependency>
    +
           
     
           
    @@ -1144,6 +1186,27 @@
             test
           
     
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-minicluster</artifactId>
+        <version>${hadoop.version}</version>
+        <scope>test</scope>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-hdfs</artifactId>
+        <version>${hadoop.version}</version>
+        <scope>test</scope>
+      </dependency>
+
+      <dependency>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-hdfs</artifactId>
+        <version>${hadoop.version}</version>
+        <classifier>tests</classifier>
+        <scope>test</scope>
+      </dependency>
         
       
     
    diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml
    index 2c5465499995a..88ff0f2d937ba 100644
    --- a/runners/apex/pom.xml
    +++ b/runners/apex/pom.xml
    @@ -261,7 +261,7 @@
                     com.esotericsoftware.kryo:kryo::${apex.kryo.version}
                     com.datatorrent:netlet::1.3.0
                     org.slf4j:slf4j-api:jar:1.7.14
    -                org.apache.hadoop:hadoop-common:jar:2.6.0
    +                org.apache.hadoop:hadoop-common:jar:${hadoop.version}
                     joda-time:joda-time:jar:2.4
                     com.google.guava:guava:jar:20.0
                   
    diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
    index ee72dd96fdf5c..1d934273ac80e 100644
    --- a/runners/spark/pom.xml
    +++ b/runners/spark/pom.xml
    @@ -34,8 +34,6 @@
       
         UTF-8
         UTF-8
    -    1.6.3
    -    2.2.0
         0.9.0.1
         2.4.4
         3.1.2
    @@ -135,31 +133,26 @@
         
           org.apache.spark
           spark-core_2.10
    -      ${spark.version}
           provided
         
         
           org.apache.spark
           spark-streaming_2.10
    -      ${spark.version}
           provided
         
         
           org.apache.spark
           spark-network-common_2.10
    -      ${spark.version}
           provided
         
         
           org.apache.hadoop
           hadoop-common
    -      ${hadoop.version}
           provided
         
         
           org.apache.hadoop
           hadoop-mapreduce-client-core
    -      ${hadoop.version}
           provided
         
         
    diff --git a/sdks/java/extensions/sorter/pom.xml b/sdks/java/extensions/sorter/pom.xml
    index 9d25f9d991ddb..ac61f76e6d237 100644
    --- a/sdks/java/extensions/sorter/pom.xml
    +++ b/sdks/java/extensions/sorter/pom.xml
    @@ -29,10 +29,6 @@
       beam-sdks-java-extensions-sorter
       Apache Beam :: SDKs :: Java :: Extensions :: Sorter
     
    -  
    -    2.7.1
    -  
    -
       
         
           org.apache.beam
    @@ -42,14 +38,12 @@
         
           org.apache.hadoop
           hadoop-mapreduce-client-core
    -      ${hadoop.version}
           provided
         
         
         
           org.apache.hadoop
           hadoop-common
    -      ${hadoop.version}
           provided
         
         
    diff --git a/sdks/java/io/hadoop-file-system/pom.xml b/sdks/java/io/hadoop-file-system/pom.xml
    index db5a1db1783ea..a54977e8d822c 100644
    --- a/sdks/java/io/hadoop-file-system/pom.xml
    +++ b/sdks/java/io/hadoop-file-system/pom.xml
    @@ -44,37 +44,6 @@
         
       
     
    -  
    -    
    -    2.7.3
    -  
    -
    -  
    -    
    -    
    -      
    -        org.apache.hadoop
    -        hadoop-hdfs
    -        tests
    -        ${hadoop.version}
    -      
    -
    -      
    -        org.apache.hadoop
    -        hadoop-minicluster
    -        ${hadoop.version}
    -      
    -    
    -  
    -
       
         
           org.apache.beam
    diff --git a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    index 9f84e881de0c6..baaa9821abb64 100644
    --- a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    +++ b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    @@ -108,13 +108,11 @@
             
               org.apache.spark
               spark-streaming_2.10
    -          ${spark.version}
               runtime
             
             
               org.apache.spark
               spark-core_2.10
    -          ${spark.version}
               runtime
               
                 
    diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml
    index 4d9d600f246de..9d5e2aad86096 100644
    --- a/sdks/java/io/hbase/pom.xml
    +++ b/sdks/java/io/hbase/pom.xml
    @@ -32,7 +32,6 @@
     
       
         1.2.6
    -    2.5.1
       
     
       
    @@ -109,14 +108,18 @@
         
           org.apache.hadoop
           hadoop-minicluster
    -      ${hbase.hadoop.version}
    +      test
    +    
    +
    +    
    +      org.apache.hadoop
    +      hadoop-hdfs
           test
         
     
         
           org.apache.hadoop
           hadoop-common
    -      ${hbase.hadoop.version}
           test
         
     
    diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml
    index 19b62a5826285..8af740d8b1fb0 100644
    --- a/sdks/java/io/hcatalog/pom.xml
    +++ b/sdks/java/io/hcatalog/pom.xml
    @@ -39,14 +39,14 @@
         
           
             org.apache.maven.plugins
    -        maven-surefire-plugin
    +        maven-shade-plugin
             
    -          true
    +          false
             
           
         
       
    -  
    +
       
         
           org.apache.beam
    diff --git a/sdks/java/io/jdbc/pom.xml b/sdks/java/io/jdbc/pom.xml
    index 17c26a058f678..45ec06c73ebe7 100644
    --- a/sdks/java/io/jdbc/pom.xml
    +++ b/sdks/java/io/jdbc/pom.xml
    @@ -49,13 +49,11 @@
             
               org.apache.spark
               spark-streaming_2.10
    -          ${spark.version}
               runtime
             
             
               org.apache.spark
               spark-core_2.10
    -          ${spark.version}
               runtime
               
                 
    diff --git a/sdks/java/io/pom.xml b/sdks/java/io/pom.xml
    index e5db41b726292..458dfaf6c79a6 100644
    --- a/sdks/java/io/pom.xml
    +++ b/sdks/java/io/pom.xml
    @@ -32,37 +32,6 @@
       Beam SDK Java IO provides different connectivity components
       (sources and sinks) to consume and produce data from systems.
     
    -  
    -    
    -    2.7.3
    -  
    -
    -  
    -    
    -      
    -        org.apache.hadoop
    -        hadoop-client
    -        ${hadoop.version}
    -      
    -
    -      
    -        org.apache.hadoop
    -        hadoop-common
    -        ${hadoop.version}
    -      
    -
    -      
    -        org.apache.hadoop
    -        hadoop-mapreduce-client-core
    -        ${hadoop.version}
    -      
    -    
    -  
    -
       
         amqp
         cassandra
    diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml
    index 54dae3ad7268d..08d5ec6eb0896 100644
    --- a/sdks/java/javadoc/pom.xml
    +++ b/sdks/java/javadoc/pom.xml
    @@ -196,13 +196,11 @@
         
           org.apache.spark
           spark-core_2.10
    -      ${spark.version}
         
     
         
           org.apache.spark
           spark-streaming_2.10
    -      ${spark.version}
         
       
     
    diff --git a/sdks/java/maven-archetypes/examples-java8/src/main/resources/archetype-resources/pom.xml b/sdks/java/maven-archetypes/examples-java8/src/main/resources/archetype-resources/pom.xml
    index af4fbd3832e38..45178614af421 100644
    --- a/sdks/java/maven-archetypes/examples-java8/src/main/resources/archetype-resources/pom.xml
    +++ b/sdks/java/maven-archetypes/examples-java8/src/main/resources/archetype-resources/pom.xml
    @@ -242,7 +242,6 @@
             
               org.apache.spark
               spark-streaming_2.10
    -          ${spark.version}
               runtime
               
                 
    diff --git a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml
    index b8b9c9f0fa492..d039ddba5a896 100644
    --- a/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml
    +++ b/sdks/java/maven-archetypes/examples/src/main/resources/archetype-resources/pom.xml
    @@ -241,7 +241,6 @@
             
               org.apache.spark
               spark-streaming_2.10
    -          ${spark.version}
               runtime
               
                 
    
    From 75475ef3dc23a09fa9bbba478d6fdbc468f7dd2e Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= 
    Date: Wed, 28 Jun 2017 16:58:55 +0200
    Subject: [PATCH 153/200] [BEAM-2530] Fix compilation of modules with Java 9
     that depend on jdk.tools
    
    ---
     runners/apex/pom.xml          |  7 +++++++
     runners/spark/pom.xml         |  7 +++++++
     sdks/java/io/hbase/pom.xml    |  7 +++++++
     sdks/java/io/hcatalog/pom.xml | 12 ++++++++++++
     4 files changed, 33 insertions(+)
    
    diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml
    index 88ff0f2d937ba..20f2d281d5783 100644
    --- a/runners/apex/pom.xml
    +++ b/runners/apex/pom.xml
    @@ -75,6 +75,13 @@
           apex-engine
           ${apex.core.version}
           runtime
    +      
    +        
    +        
    +          jdk.tools
    +          jdk.tools
    +        
    +      
         
     
         
    diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
    index 1d934273ac80e..8a69496639ace 100644
    --- a/runners/spark/pom.xml
    +++ b/runners/spark/pom.xml
    @@ -149,6 +149,13 @@
           org.apache.hadoop
           hadoop-common
           provided
    +      
    +        
    +        
    +          jdk.tools
    +          jdk.tools
    +        
    +      
         
         
           org.apache.hadoop
    diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml
    index 9d5e2aad86096..40ac8dfc0bccf 100644
    --- a/sdks/java/io/hbase/pom.xml
    +++ b/sdks/java/io/hbase/pom.xml
    @@ -121,6 +121,13 @@
           org.apache.hadoop
           hadoop-common
           test
    +      
    +        
    +        
    +          jdk.tools
    +          jdk.tools
    +        
    +      
         
     
         
    diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml
    index 8af740d8b1fb0..a31ff86d89f22 100644
    --- a/sdks/java/io/hcatalog/pom.xml
    +++ b/sdks/java/io/hcatalog/pom.xml
    @@ -61,6 +61,13 @@
         
           org.apache.hadoop
           hadoop-common
    +      
    +        
    +        
    +          jdk.tools
    +          jdk.tools
    +        
    +      
         
     
         
    @@ -109,6 +116,11 @@
               com.google.protobuf
               protobuf-java
             
    +        
    +        
    +          jdk.tools
    +          jdk.tools
    +        
           
         
     
    
    From 68f1fb64fd2565e287e322d715ca778d01e7137b Mon Sep 17 00:00:00 2001
    From: Ahmet Altay 
    Date: Fri, 30 Jun 2017 17:37:33 -0700
    Subject: [PATCH 154/200] For GCS operations use an http client with a default
     timeout value.
    
    ---
     sdks/python/apache_beam/io/gcp/gcsio.py | 10 +++++++++-
     1 file changed, 9 insertions(+), 1 deletion(-)
    
    diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py
    index d43c8ba0d9728..643fbc75c0020 100644
    --- a/sdks/python/apache_beam/io/gcp/gcsio.py
    +++ b/sdks/python/apache_beam/io/gcp/gcsio.py
    @@ -31,6 +31,7 @@
     import threading
     import time
     import traceback
    +import httplib2
     
     from apache_beam.utils import retry
     
    @@ -68,6 +69,10 @@
     # +---------------+------------+-------------+-------------+-------------+
     DEFAULT_READ_BUFFER_SIZE = 16 * 1024 * 1024
     
    +# This is the number of seconds the library will wait for GCS operations to
    +# complete.
    +DEFAULT_HTTP_TIMEOUT_SECONDS = 60
    +
     # This is the number of seconds the library will wait for a partial-file read
     # operation from GCS to complete before retrying.
     DEFAULT_READ_SEGMENT_TIMEOUT_SECONDS = 60
    @@ -99,6 +104,7 @@ class GcsIO(object):
     
       def __new__(cls, storage_client=None):
         if storage_client:
    +      # This path is only used for testing.
           return super(GcsIO, cls).__new__(cls, storage_client)
         else:
           # Create a single storage client for each thread.  We would like to avoid
    @@ -108,7 +114,9 @@ def __new__(cls, storage_client=None):
           local_state = threading.local()
           if getattr(local_state, 'gcsio_instance', None) is None:
             credentials = auth.get_service_credentials()
    -        storage_client = storage.StorageV1(credentials=credentials)
    +        storage_client = storage.StorageV1(
    +            credentials=credentials,
    +            http=httplib2.Http(timeout=DEFAULT_HTTP_TIMEOUT_SECONDS))
             local_state.gcsio_instance = (
                 super(GcsIO, cls).__new__(cls, storage_client))
             local_state.gcsio_instance.client = storage_client
    
    From 51877a3405dbf778c3bb88f19bb194e54c3b3def Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= 
    Date: Wed, 5 Jul 2017 16:47:29 +0200
    Subject: [PATCH 155/200] [maven-release-plugin] prepare branch release-2.1.0
    
    ---
     pom.xml                     | 2 +-
     runners/direct-java/pom.xml | 2 +-
     2 files changed, 2 insertions(+), 2 deletions(-)
    
    diff --git a/pom.xml b/pom.xml
    index c0207ef60ecf1..057954aadceac 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -48,7 +48,7 @@
         scm:git:https://git-wip-us.apache.org/repos/asf/beam.git
         scm:git:https://git-wip-us.apache.org/repos/asf/beam.git
         https://git-wip-us.apache.org/repos/asf?p=beam.git;a=summary
    -    HEAD
    +    release-2.1.0
       
     
       
    diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml
    index 63465757ff14a..5b5aec22e2f5a 100644
    --- a/runners/direct-java/pom.xml
    +++ b/runners/direct-java/pom.xml
    @@ -117,7 +117,7 @@
                       
                     
                     
    -                  
    +                  
                     
                   
                 
    
    From 7f0723cf7fd587c4f8dbe8a8b1c9d298a7b1e5e3 Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= 
    Date: Wed, 5 Jul 2017 16:47:38 +0200
    Subject: [PATCH 156/200] [maven-release-plugin] prepare for next development
     iteration
    
    ---
     examples/java/pom.xml                                   | 2 +-
     examples/java8/pom.xml                                  | 2 +-
     examples/pom.xml                                        | 2 +-
     pom.xml                                                 | 4 ++--
     runners/apex/pom.xml                                    | 2 +-
     runners/core-construction-java/pom.xml                  | 2 +-
     runners/core-java/pom.xml                               | 2 +-
     runners/direct-java/pom.xml                             | 2 +-
     runners/flink/pom.xml                                   | 2 +-
     runners/google-cloud-dataflow-java/pom.xml              | 2 +-
     runners/pom.xml                                         | 2 +-
     runners/spark/pom.xml                                   | 2 +-
     sdks/common/fn-api/pom.xml                              | 2 +-
     sdks/common/pom.xml                                     | 2 +-
     sdks/common/runner-api/pom.xml                          | 2 +-
     sdks/java/build-tools/pom.xml                           | 2 +-
     sdks/java/core/pom.xml                                  | 2 +-
     sdks/java/extensions/google-cloud-platform-core/pom.xml | 2 +-
     sdks/java/extensions/jackson/pom.xml                    | 2 +-
     sdks/java/extensions/join-library/pom.xml               | 2 +-
     sdks/java/extensions/pom.xml                            | 2 +-
     sdks/java/extensions/protobuf/pom.xml                   | 2 +-
     sdks/java/extensions/sorter/pom.xml                     | 2 +-
     sdks/java/harness/pom.xml                               | 2 +-
     sdks/java/io/amqp/pom.xml                               | 2 +-
     sdks/java/io/cassandra/pom.xml                          | 2 +-
     sdks/java/io/common/pom.xml                             | 2 +-
     sdks/java/io/elasticsearch/pom.xml                      | 2 +-
     sdks/java/io/google-cloud-platform/pom.xml              | 2 +-
     sdks/java/io/hadoop-common/pom.xml                      | 2 +-
     sdks/java/io/hadoop-file-system/pom.xml                 | 2 +-
     sdks/java/io/hadoop/input-format/pom.xml                | 2 +-
     sdks/java/io/hadoop/jdk1.8-tests/pom.xml                | 2 +-
     sdks/java/io/hadoop/pom.xml                             | 2 +-
     sdks/java/io/hbase/pom.xml                              | 2 +-
     sdks/java/io/hcatalog/pom.xml                           | 2 +-
     sdks/java/io/jdbc/pom.xml                               | 2 +-
     sdks/java/io/jms/pom.xml                                | 2 +-
     sdks/java/io/kafka/pom.xml                              | 2 +-
     sdks/java/io/kinesis/pom.xml                            | 2 +-
     sdks/java/io/mongodb/pom.xml                            | 2 +-
     sdks/java/io/mqtt/pom.xml                               | 2 +-
     sdks/java/io/pom.xml                                    | 2 +-
     sdks/java/io/xml/pom.xml                                | 2 +-
     sdks/java/java8tests/pom.xml                            | 2 +-
     sdks/java/javadoc/pom.xml                               | 2 +-
     sdks/java/maven-archetypes/examples-java8/pom.xml       | 2 +-
     sdks/java/maven-archetypes/examples/pom.xml             | 2 +-
     sdks/java/maven-archetypes/pom.xml                      | 2 +-
     sdks/java/maven-archetypes/starter/pom.xml              | 2 +-
     sdks/java/pom.xml                                       | 2 +-
     sdks/pom.xml                                            | 2 +-
     sdks/python/pom.xml                                     | 2 +-
     53 files changed, 54 insertions(+), 54 deletions(-)
    
    diff --git a/examples/java/pom.xml b/examples/java/pom.xml
    index 7ae4e6ad36e66..ae64a79340d0b 100644
    --- a/examples/java/pom.xml
    +++ b/examples/java/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-examples-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/examples/java8/pom.xml b/examples/java8/pom.xml
    index a0ce708b6125b..6fd29a496bb82 100644
    --- a/examples/java8/pom.xml
    +++ b/examples/java8/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-examples-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/examples/pom.xml b/examples/pom.xml
    index a7e61dd2ff1e3..51f4c35030e25 100644
    --- a/examples/pom.xml
    +++ b/examples/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/pom.xml b/pom.xml
    index 057954aadceac..49e6eddc8ac00 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -34,7 +34,7 @@
       http://beam.apache.org/
       2016
     
    -  2.1.0-SNAPSHOT
    +  2.2.0-SNAPSHOT
     
       
         
    @@ -48,7 +48,7 @@
         scm:git:https://git-wip-us.apache.org/repos/asf/beam.git
         scm:git:https://git-wip-us.apache.org/repos/asf/beam.git
         https://git-wip-us.apache.org/repos/asf?p=beam.git;a=summary
    -    release-2.1.0
    +    HEAD
       
     
       
    diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml
    index 20f2d281d5783..fd5aafb992be6 100644
    --- a/runners/apex/pom.xml
    +++ b/runners/apex/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/core-construction-java/pom.xml b/runners/core-construction-java/pom.xml
    index 67951e9334f41..b85b5f5ff5413 100644
    --- a/runners/core-construction-java/pom.xml
    +++ b/runners/core-construction-java/pom.xml
    @@ -24,7 +24,7 @@
       
         beam-runners-parent
         org.apache.beam
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/core-java/pom.xml b/runners/core-java/pom.xml
    index c3a8d254f4dd9..8c8e5996627e0 100644
    --- a/runners/core-java/pom.xml
    +++ b/runners/core-java/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml
    index 5b5aec22e2f5a..0e1f73a4f3cb5 100644
    --- a/runners/direct-java/pom.xml
    +++ b/runners/direct-java/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml
    index 339aa8e445a97..c063a2de425e2 100644
    --- a/runners/flink/pom.xml
    +++ b/runners/flink/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml
    index 2ba163bdf84e0..91908cdcbbcf4 100644
    --- a/runners/google-cloud-dataflow-java/pom.xml
    +++ b/runners/google-cloud-dataflow-java/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/pom.xml b/runners/pom.xml
    index 38aada80aa2d9..b00ba9ccc56a8 100644
    --- a/runners/pom.xml
    +++ b/runners/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
    index 8a69496639ace..7f70204e00943 100644
    --- a/runners/spark/pom.xml
    +++ b/runners/spark/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-runners-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/common/fn-api/pom.xml b/sdks/common/fn-api/pom.xml
    index 77a9ba52e40b1..6810667d9091b 100644
    --- a/sdks/common/fn-api/pom.xml
    +++ b/sdks/common/fn-api/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-common-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/common/pom.xml b/sdks/common/pom.xml
    index c621ed51087b7..40eefa7cc5cd9 100644
    --- a/sdks/common/pom.xml
    +++ b/sdks/common/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/common/runner-api/pom.xml b/sdks/common/runner-api/pom.xml
    index f5536a76a8ce8..8bc4123cb541d 100644
    --- a/sdks/common/runner-api/pom.xml
    +++ b/sdks/common/runner-api/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-common-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/build-tools/pom.xml b/sdks/java/build-tools/pom.xml
    index 5a2c498553dbb..d7d25f65ed702 100644
    --- a/sdks/java/build-tools/pom.xml
    +++ b/sdks/java/build-tools/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../../../pom.xml
       
     
    diff --git a/sdks/java/core/pom.xml b/sdks/java/core/pom.xml
    index 11b68e664eac2..3f12dc48f856a 100644
    --- a/sdks/java/core/pom.xml
    +++ b/sdks/java/core/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/google-cloud-platform-core/pom.xml b/sdks/java/extensions/google-cloud-platform-core/pom.xml
    index e4e951b20f547..7d54990f9c3d1 100644
    --- a/sdks/java/extensions/google-cloud-platform-core/pom.xml
    +++ b/sdks/java/extensions/google-cloud-platform-core/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-extensions-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/jackson/pom.xml b/sdks/java/extensions/jackson/pom.xml
    index 4b09c1176fafe..7fd38e0d4198a 100644
    --- a/sdks/java/extensions/jackson/pom.xml
    +++ b/sdks/java/extensions/jackson/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-extensions-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/join-library/pom.xml b/sdks/java/extensions/join-library/pom.xml
    index 556ec4086ec91..ea24b7510c296 100644
    --- a/sdks/java/extensions/join-library/pom.xml
    +++ b/sdks/java/extensions/join-library/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-extensions-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/pom.xml b/sdks/java/extensions/pom.xml
    index 3d63626f30bd8..1222476ec904f 100644
    --- a/sdks/java/extensions/pom.xml
    +++ b/sdks/java/extensions/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/protobuf/pom.xml b/sdks/java/extensions/protobuf/pom.xml
    index ae909abb49b15..63855f87e9bb1 100644
    --- a/sdks/java/extensions/protobuf/pom.xml
    +++ b/sdks/java/extensions/protobuf/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-extensions-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/extensions/sorter/pom.xml b/sdks/java/extensions/sorter/pom.xml
    index ac61f76e6d237..395c73fae1da7 100644
    --- a/sdks/java/extensions/sorter/pom.xml
    +++ b/sdks/java/extensions/sorter/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-extensions-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml
    index a35481d7b58d7..9cfadc215edac 100644
    --- a/sdks/java/harness/pom.xml
    +++ b/sdks/java/harness/pom.xml
    @@ -23,7 +23,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/amqp/pom.xml b/sdks/java/io/amqp/pom.xml
    index 45b295dfce244..8da94483bff52 100644
    --- a/sdks/java/io/amqp/pom.xml
    +++ b/sdks/java/io/amqp/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/cassandra/pom.xml b/sdks/java/io/cassandra/pom.xml
    index 8249f57c0a347..c74477e9741fb 100644
    --- a/sdks/java/io/cassandra/pom.xml
    +++ b/sdks/java/io/cassandra/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/common/pom.xml b/sdks/java/io/common/pom.xml
    index f7525fd87c17d..df0d94bea53c2 100644
    --- a/sdks/java/io/common/pom.xml
    +++ b/sdks/java/io/common/pom.xml
    @@ -22,7 +22,7 @@
         
             org.apache.beam
             beam-sdks-java-io-parent
    -        2.1.0-SNAPSHOT
    +        2.2.0-SNAPSHOT
             ../pom.xml
         
     
    diff --git a/sdks/java/io/elasticsearch/pom.xml b/sdks/java/io/elasticsearch/pom.xml
    index c8e308c3ceac8..e0a7f21e0f23a 100644
    --- a/sdks/java/io/elasticsearch/pom.xml
    +++ b/sdks/java/io/elasticsearch/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml
    index 09a430a0d146f..a1495f2df8ea9 100644
    --- a/sdks/java/io/google-cloud-platform/pom.xml
    +++ b/sdks/java/io/google-cloud-platform/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/hadoop-common/pom.xml b/sdks/java/io/hadoop-common/pom.xml
    index 8749243e1e767..4bcbcd742dd0d 100644
    --- a/sdks/java/io/hadoop-common/pom.xml
    +++ b/sdks/java/io/hadoop-common/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/hadoop-file-system/pom.xml b/sdks/java/io/hadoop-file-system/pom.xml
    index a54977e8d822c..a9c2e57b15dbf 100644
    --- a/sdks/java/io/hadoop-file-system/pom.xml
    +++ b/sdks/java/io/hadoop-file-system/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/hadoop/input-format/pom.xml b/sdks/java/io/hadoop/input-format/pom.xml
    index 06f9f113d1f22..095311969d8e3 100644
    --- a/sdks/java/io/hadoop/input-format/pom.xml
    +++ b/sdks/java/io/hadoop/input-format/pom.xml
    @@ -20,7 +20,7 @@
       
         org.apache.beam
         beam-sdks-java-io-hadoop-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
       beam-sdks-java-io-hadoop-input-format
    diff --git a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    index baaa9821abb64..12944f49c3581 100644
    --- a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    +++ b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
    @@ -26,7 +26,7 @@
       
         org.apache.beam
         beam-sdks-java-io-hadoop-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
       beam-sdks-java-io-hadoop-jdk1.8-tests
    diff --git a/sdks/java/io/hadoop/pom.xml b/sdks/java/io/hadoop/pom.xml
    index a1c7a2e94e37b..bc3569def304d 100644
    --- a/sdks/java/io/hadoop/pom.xml
    +++ b/sdks/java/io/hadoop/pom.xml
    @@ -20,7 +20,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
       pom
    diff --git a/sdks/java/io/hbase/pom.xml b/sdks/java/io/hbase/pom.xml
    index 40ac8dfc0bccf..40f516abcb708 100644
    --- a/sdks/java/io/hbase/pom.xml
    +++ b/sdks/java/io/hbase/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml
    index a31ff86d89f22..2aa661ef7edfb 100644
    --- a/sdks/java/io/hcatalog/pom.xml
    +++ b/sdks/java/io/hcatalog/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/jdbc/pom.xml b/sdks/java/io/jdbc/pom.xml
    index 45ec06c73ebe7..050fc6a5facc2 100644
    --- a/sdks/java/io/jdbc/pom.xml
    +++ b/sdks/java/io/jdbc/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/jms/pom.xml b/sdks/java/io/jms/pom.xml
    index 58009a102eea2..c2074afc8d866 100644
    --- a/sdks/java/io/jms/pom.xml
    +++ b/sdks/java/io/jms/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/kafka/pom.xml b/sdks/java/io/kafka/pom.xml
    index 29350ccfd431e..1256c46d17bfd 100644
    --- a/sdks/java/io/kafka/pom.xml
    +++ b/sdks/java/io/kafka/pom.xml
    @@ -21,7 +21,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/kinesis/pom.xml b/sdks/java/io/kinesis/pom.xml
    index cb7064bcdd825..46d5e2604303c 100644
    --- a/sdks/java/io/kinesis/pom.xml
    +++ b/sdks/java/io/kinesis/pom.xml
    @@ -21,7 +21,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/mongodb/pom.xml b/sdks/java/io/mongodb/pom.xml
    index 912e20cb785bc..d93cc41b41915 100644
    --- a/sdks/java/io/mongodb/pom.xml
    +++ b/sdks/java/io/mongodb/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/mqtt/pom.xml b/sdks/java/io/mqtt/pom.xml
    index baaf771dcfb24..9fa1dc07f67fc 100644
    --- a/sdks/java/io/mqtt/pom.xml
    +++ b/sdks/java/io/mqtt/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/pom.xml b/sdks/java/io/pom.xml
    index 458dfaf6c79a6..b7909fa7f6c7e 100644
    --- a/sdks/java/io/pom.xml
    +++ b/sdks/java/io/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/io/xml/pom.xml b/sdks/java/io/xml/pom.xml
    index cf7dd3364d5e6..7b5804eea9b46 100644
    --- a/sdks/java/io/xml/pom.xml
    +++ b/sdks/java/io/xml/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-io-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/java8tests/pom.xml b/sdks/java/java8tests/pom.xml
    index b90a757c7f448..2378014af2f54 100644
    --- a/sdks/java/java8tests/pom.xml
    +++ b/sdks/java/java8tests/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml
    index 08d5ec6eb0896..ddb92cfb8e2b6 100644
    --- a/sdks/java/javadoc/pom.xml
    +++ b/sdks/java/javadoc/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../../../pom.xml
       
     
    diff --git a/sdks/java/maven-archetypes/examples-java8/pom.xml b/sdks/java/maven-archetypes/examples-java8/pom.xml
    index b57644d74aaef..b60a6954acbfa 100644
    --- a/sdks/java/maven-archetypes/examples-java8/pom.xml
    +++ b/sdks/java/maven-archetypes/examples-java8/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-maven-archetypes-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/maven-archetypes/examples/pom.xml b/sdks/java/maven-archetypes/examples/pom.xml
    index c1378cbc385c0..2a02039052207 100644
    --- a/sdks/java/maven-archetypes/examples/pom.xml
    +++ b/sdks/java/maven-archetypes/examples/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-maven-archetypes-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/maven-archetypes/pom.xml b/sdks/java/maven-archetypes/pom.xml
    index b7fe2747daa28..d676b316652f1 100644
    --- a/sdks/java/maven-archetypes/pom.xml
    +++ b/sdks/java/maven-archetypes/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/maven-archetypes/starter/pom.xml b/sdks/java/maven-archetypes/starter/pom.xml
    index 06b41c8f08507..8024b52779b9b 100644
    --- a/sdks/java/maven-archetypes/starter/pom.xml
    +++ b/sdks/java/maven-archetypes/starter/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-java-maven-archetypes-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/java/pom.xml b/sdks/java/pom.xml
    index 250c85aa74604..3144193b9839e 100644
    --- a/sdks/java/pom.xml
    +++ b/sdks/java/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/pom.xml b/sdks/pom.xml
    index 27b9610d11e6a..aec8762fb8c99 100644
    --- a/sdks/pom.xml
    +++ b/sdks/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    diff --git a/sdks/python/pom.xml b/sdks/python/pom.xml
    index 129565424e057..10776892ba114 100644
    --- a/sdks/python/pom.xml
    +++ b/sdks/python/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.beam
         beam-sdks-parent
    -    2.1.0-SNAPSHOT
    +    2.2.0-SNAPSHOT
         ../pom.xml
       
     
    
    From 14fa7f79f0830739122b7573e032ad0aea172a98 Mon Sep 17 00:00:00 2001
    From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= 
    Date: Wed, 5 Jul 2017 16:52:48 +0200
    Subject: [PATCH 157/200] Update Python SDK version
    
    ---
     sdks/python/apache_beam/version.py | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/sdks/python/apache_beam/version.py b/sdks/python/apache_beam/version.py
    index ae92a235c53bc..8b0a430ddb739 100644
    --- a/sdks/python/apache_beam/version.py
    +++ b/sdks/python/apache_beam/version.py
    @@ -18,4 +18,4 @@
     """Apache Beam SDK version information and utilities."""
     
     
    -__version__ = '2.1.0.dev'
    +__version__ = '2.2.0.dev'
    
    From 06897b1cc142f658437ac7779c849e5182e331f1 Mon Sep 17 00:00:00 2001
    From: Jason Kuster 
    Date: Fri, 9 Jun 2017 01:39:15 -0700
    Subject: [PATCH 158/200] Website Mergebot Job
    
    Signed-off-by: Jason Kuster 
    ---
     .../jenkins/common_job_properties.groovy      |  5 +-
     .../job_beam_PreCommit_Website_Merge.groovy   | 59 +++++++++++++++++++
     2 files changed, 62 insertions(+), 2 deletions(-)
     create mode 100644 .test-infra/jenkins/job_beam_PreCommit_Website_Merge.groovy
    
    diff --git a/.test-infra/jenkins/common_job_properties.groovy b/.test-infra/jenkins/common_job_properties.groovy
    index 0e047eac70d56..70534c6ac3bdb 100644
    --- a/.test-infra/jenkins/common_job_properties.groovy
    +++ b/.test-infra/jenkins/common_job_properties.groovy
    @@ -23,11 +23,12 @@
     class common_job_properties {
     
       // Sets common top-level job properties for website repository jobs.
    -  static void setTopLevelWebsiteJobProperties(context) {
    +  static void setTopLevelWebsiteJobProperties(context,
    +                                              String branch = 'asf-site') {
         setTopLevelJobProperties(
                 context,
                 'beam-site',
    -            'asf-site',
    +            branch,
                 'beam',
                 30)
       }
    diff --git a/.test-infra/jenkins/job_beam_PreCommit_Website_Merge.groovy b/.test-infra/jenkins/job_beam_PreCommit_Website_Merge.groovy
    new file mode 100644
    index 0000000000000..0e2ae3fc5526b
    --- /dev/null
    +++ b/.test-infra/jenkins/job_beam_PreCommit_Website_Merge.groovy
    @@ -0,0 +1,59 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +import common_job_properties
    +
    +// Defines a job.
    +job('beam_PreCommit_Website_Merge') {
    +  description('Runs website tests for mergebot.')
    +
    +  // Set common parameters.
    +  common_job_properties.setTopLevelWebsiteJobProperties(delegate, 'mergebot')
    +
    +  triggers {
    +    githubPush()
    +  }
    +
    +  steps {
    +    // Run the following shell script as a build step.
    +    shell '''
    +        # Install RVM per instructions at https://rvm.io/rvm/install.
    +        RVM_GPG_KEY=409B6B1796C275462A1703113804BB82D39DC0E3
    +        gpg --keyserver hkp://keys.gnupg.net --recv-keys $RVM_GPG_KEY
    +            
    +        \\curl -sSL https://get.rvm.io | bash
    +        source /home/jenkins/.rvm/scripts/rvm
    +
    +        # Install Ruby.
    +        RUBY_VERSION_NUM=2.3.0
    +        rvm install ruby $RUBY_VERSION_NUM --autolibs=read-only
    +
    +        # Install Bundler gem
    +        PATH=~/.gem/ruby/$RUBY_VERSION_NUM/bin:$PATH
    +        GEM_PATH=~/.gem/ruby/$RUBY_VERSION_NUM/:$GEM_PATH
    +        gem install bundler --user-install
    +
    +        # Install all needed gems.
    +        bundle install --path ~/.gem/
    +
    +        # Build the new site and test it.
    +        rm -fr ./content/
    +        bundle exec rake test
    +    '''.stripIndent().trim()
    +  }
    +}
    
    From 6ca410a908f1f4e7ac1e141ee1335f7a537bb150 Mon Sep 17 00:00:00 2001
    From: Luke Cwik 
    Date: Wed, 5 Jul 2017 10:38:44 -0700
    Subject: [PATCH 159/200] [BEAM-2553] Update Maven exec plugin to 1.6.0 to
     incorporate messaging improvements
    
    ---
     pom.xml                                                         | 2 +-
     .../starter/src/test/resources/projects/basic/reference/pom.xml | 2 +-
     2 files changed, 2 insertions(+), 2 deletions(-)
    
    diff --git a/pom.xml b/pom.xml
    index 49e6eddc8ac00..01474c1a001c7 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -159,7 +159,7 @@
         2.20
         3.6.1
         3.0.1
    -    1.4.0
    +    1.6.0
         3.0.2
         3.0.2
         3.0.0
    diff --git a/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml b/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
    index 60405e6dbb178..6056fb0083ea7 100644
    --- a/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
    +++ b/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
    @@ -28,7 +28,7 @@
         @project.version@
     
         3.6.1
    -    1.4.0
    +    1.6.0
         1.7.14
       
     
    
    From 29c2bca4649317f2ebb1c89f92bf97fbb27602ca Mon Sep 17 00:00:00 2001
    From: Thomas Groh 
    Date: Wed, 5 Jul 2017 14:16:50 -0700
    Subject: [PATCH 160/200] Disallow Combiner Lifting for multi-window WindowFns
    
    ---
     .../apache/beam/runners/dataflow/DataflowPipelineTranslator.java | 1 +
     1 file changed, 1 insertion(+)
    
    diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
    index 28fd1bb1af02c..f1783def50409 100644
    --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
    +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
    @@ -793,6 +793,7 @@ private  void groupByKeyHelper(
                     context.getPipelineOptions().as(StreamingOptions.class).isStreaming();
                 boolean disallowCombinerLifting =
                     !windowingStrategy.getWindowFn().isNonMerging()
    +                    || !windowingStrategy.getWindowFn().assignsToOneWindow()
                         || (isStreaming && !transform.fewKeys())
                         // TODO: Allow combiner lifting on the non-default trigger, as appropriate.
                         || !(windowingStrategy.getTrigger() instanceof DefaultTrigger);
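
The one-line change above adds another case in which the Dataflow translator must keep
the full GroupByKey rather than lifting the combiner. A hedged sketch of the overall
decision, written as plain Python with illustrative predicate names standing in for the
Java calls in the diff:

    def disallow_combiner_lifting(window_fn, is_streaming, few_keys, trigger_is_default):
        """Combiner lifting is only safe when every element lands in exactly one
        non-merging window, streaming jobs have few keys, and the default trigger
        is used. The assigns_to_one_window() check is the condition this patch adds.
        """
        return (not window_fn.is_non_merging()
                or not window_fn.assigns_to_one_window()
                or (is_streaming and not few_keys)
                or not trigger_is_default)
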
    
    From 23e385faa193a00f9b10e3f8f0afe832087bff06 Mon Sep 17 00:00:00 2001
    From: Ahmet Altay 
    Date: Wed, 5 Jul 2017 14:34:07 -0700
    Subject: [PATCH 161/200] Update SDK dependencies
    
    ---
     sdks/python/setup.py | 4 ++--
     1 file changed, 2 insertions(+), 2 deletions(-)
    
    diff --git a/sdks/python/setup.py b/sdks/python/setup.py
    index 6646a58e529ab..8a0c9aefab16d 100644
    --- a/sdks/python/setup.py
    +++ b/sdks/python/setup.py
    @@ -120,9 +120,9 @@ def get_version():
       'google-apitools>=0.5.10,<=0.5.11',
       'proto-google-cloud-datastore-v1>=0.90.0,<=0.90.4',
       'googledatastore==7.0.1',
    -  'google-cloud-pubsub==0.25.0',
    +  'google-cloud-pubsub==0.26.0',
       # GCP packages required by tests
    -  'google-cloud-bigquery>=0.23.0,<0.25.0',
    +  'google-cloud-bigquery>=0.23.0,<0.26.0',
     ]
     
     
    
    From a75202f344f22be5c5fdf62b3eb54a151ad29af6 Mon Sep 17 00:00:00 2001
    From: Charles Chen 
    Date: Wed, 5 Jul 2017 16:18:51 -0700
    Subject: [PATCH 162/200] Fix PValue input in _PubSubReadEvaluator
    
    ---
     .../python/apache_beam/runners/direct/transform_evaluator.py | 5 +++--
     1 file changed, 3 insertions(+), 2 deletions(-)
    
    diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
    index 641291d4857c2..cb2ace29f0eac 100644
    --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
    +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
    @@ -436,8 +436,9 @@ def finish_bundle(self):
           bundles = [bundle]
         else:
           bundles = []
    -    input_pvalue = self._applied_ptransform.inputs
    -    if not input_pvalue:
    +    if self._applied_ptransform.inputs:
    +      input_pvalue = self._applied_ptransform.inputs[0]
    +    else:
           input_pvalue = pvalue.PBegin(self._applied_ptransform.transform.pipeline)
         unprocessed_bundle = self._evaluation_context.create_bundle(
             input_pvalue)
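
The fix above attaches the unprocessed bundle to a concrete PValue: the first input of
the applied transform when one exists, otherwise a PBegin for the pipeline. A short
sketch of that selection (the helper name is made up; pvalue.PBegin is the real class
used in the diff):

    from apache_beam import pvalue

    def resolve_input_pvalue(applied_ptransform):
        """Return the PValue an unprocessed bundle should be created against."""
        if applied_ptransform.inputs:
            # A read transform with a declared input keeps using that input.
            return applied_ptransform.inputs[0]
        # No inputs: fall back to the beginning of the pipeline.
        return pvalue.PBegin(applied_ptransform.transform.pipeline)
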
    
    From 2259c309c5b81a5d1e32732dd35e1102766401fa Mon Sep 17 00:00:00 2001
    From: Raghu Angadi 
    Date: Wed, 28 Jun 2017 12:07:06 -0700
    Subject: [PATCH 163/200] [BEAM-2534] Handle offset gaps in Kafka messages.
    
KafkaIO logged a warning when there was a gap in offsets for messages.
Kafka also supports 'KV' store style (log-compacted) topics, where some
messages are deleted, leading to gaps in offsets. This PR removes that
warning and accounts for offset gaps in the backlog estimate.
    ---
     .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 49 +++++++++++--------
     1 file changed, 29 insertions(+), 20 deletions(-)
    
    diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    index 702bdd32b712b..e520367f057bf 100644
    --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    @@ -904,6 +904,22 @@ public String toString() {
           return name;
         }
     
    +    // Maintains approximate average over last 1000 elements
    +    private static class MovingAvg {
    +      private static final int MOVING_AVG_WINDOW = 1000;
    +      private double avg = 0;
    +      private long numUpdates = 0;
    +
    +      void update(double quantity) {
    +        numUpdates++;
    +        avg += (quantity - avg) / Math.min(MOVING_AVG_WINDOW, numUpdates);
    +      }
    +
    +      double get() {
    +        return avg;
    +      }
    +    }
    +
         // maintains state of each assigned partition (buffered records, consumed offset, etc)
         private static class PartitionState {
           private final TopicPartition topicPartition;
    @@ -911,9 +927,8 @@ private static class PartitionState {
           private long latestOffset;
           private Iterator> recordIter = Collections.emptyIterator();
     
    -      // simple moving average for size of each record in bytes
    -      private double avgRecordSize = 0;
    -      private static final int movingAvgWindow = 1000; // very roughly avg of last 1000 elements
    +      private MovingAvg avgRecordSize = new MovingAvg();
    +      private MovingAvg avgOffsetGap = new MovingAvg(); // > 0 only when log compaction is enabled.
     
           PartitionState(TopicPartition partition, long nextOffset) {
             this.topicPartition = partition;
    @@ -921,17 +936,13 @@ private static class PartitionState {
             this.latestOffset = UNINITIALIZED_OFFSET;
           }
     
    -      // update consumedOffset and avgRecordSize
    -      void recordConsumed(long offset, int size) {
    +      // Update consumedOffset, avgRecordSize, and avgOffsetGap
    +      void recordConsumed(long offset, int size, long offsetGap) {
             nextOffset = offset + 1;
     
    -        // this is always updated from single thread. probably not worth making it an AtomicDouble
    -        if (avgRecordSize <= 0) {
    -          avgRecordSize = size;
    -        } else {
    -          // initially, first record heavily contributes to average.
    -          avgRecordSize += ((size - avgRecordSize) / movingAvgWindow);
    -        }
+        // This is always updated from a single thread. Probably not worth making it atomic.
    +        avgRecordSize.update(size);
    +        avgOffsetGap.update(offsetGap);
           }
     
           synchronized void setLatestOffset(long latestOffset) {
    @@ -944,14 +955,15 @@ synchronized long approxBacklogInBytes() {
             if (backlogMessageCount == UnboundedReader.BACKLOG_UNKNOWN) {
               return UnboundedReader.BACKLOG_UNKNOWN;
             }
    -        return (long) (backlogMessageCount * avgRecordSize);
    +        return (long) (backlogMessageCount * avgRecordSize.get());
           }
     
           synchronized long backlogMessageCount() {
             if (latestOffset < 0 || nextOffset < 0) {
               return UnboundedReader.BACKLOG_UNKNOWN;
             }
    -        return Math.max(0, (latestOffset - nextOffset));
    +        double remaining = (latestOffset - nextOffset) / (1 + avgOffsetGap.get());
    +        return Math.max(0, (long) Math.ceil(remaining));
           }
         }
     
    @@ -1154,14 +1166,11 @@ public boolean advance() throws IOException {
                 continue;
               }
     
    -          // sanity check
    -          if (offset != expected) {
    -            LOG.warn("{}: gap in offsets for {} at {}. {} records missing.",
    -                this, pState.topicPartition, expected, offset - expected);
    -          }
    +          long offsetGap = offset - expected; // could be > 0 when Kafka log compaction is enabled.
     
               if (curRecord == null) {
                 LOG.info("{}: first record offset {}", name, offset);
    +            offsetGap = 0;
               }
     
               curRecord = null; // user coders below might throw.
    @@ -1182,7 +1191,7 @@ public boolean advance() throws IOException {
     
               int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length)
                   + (rawRecord.value() == null ? 0 : rawRecord.value().length);
    -          pState.recordConsumed(offset, recordSize);
    +          pState.recordConsumed(offset, recordSize, offsetGap);
               bytesRead.inc(recordSize);
               bytesReadBySplit.inc(recordSize);
               return true;
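
A hedged Python sketch of the two ideas in this patch: a bounded moving average over
roughly the last 1000 observations, and a backlog estimate that divides the raw offset
range by the average gap between consecutive offsets so that log-compacted topics are
not over-counted. Names and the standalone function are illustrative.

    import math

    class MovingAvg(object):
        """Approximate average over roughly the last MOVING_AVG_WINDOW updates."""
        MOVING_AVG_WINDOW = 1000

        def __init__(self):
            self.avg = 0.0
            self.num_updates = 0

        def update(self, quantity):
            self.num_updates += 1
            self.avg += (quantity - self.avg) / min(self.MOVING_AVG_WINDOW,
                                                    self.num_updates)

        def get(self):
            return self.avg

    def backlog_message_count(latest_offset, next_offset, avg_offset_gap):
        """Estimate remaining messages, discounting gaps caused by log compaction."""
        if latest_offset < 0 or next_offset < 0:
            return None  # backlog unknown
        remaining = (latest_offset - next_offset) / (1.0 + avg_offset_gap)
        return max(0, int(math.ceil(remaining)))
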
    
    From 29c35cdece61f65881a94d480257291a1bd7fc83 Mon Sep 17 00:00:00 2001
    From: Kenneth Knowles 
    Date: Thu, 6 Jul 2017 11:07:38 -0700
    Subject: [PATCH 164/200] Update Dataflow container version to 20170706
    
    ---
     runners/google-cloud-dataflow-java/pom.xml | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml
    index 91908cdcbbcf4..c8d63ac56a272 100644
    --- a/runners/google-cloud-dataflow-java/pom.xml
    +++ b/runners/google-cloud-dataflow-java/pom.xml
    @@ -33,7 +33,7 @@
       jar
     
       
    -    beam-master-20170623
    +    beam-master-20170706
         1
         6
       
    
    From 526037b6786315b9f9fdca6edb636baeb6f83e3f Mon Sep 17 00:00:00 2001
    From: Raghu Angadi 
    Date: Mon, 3 Jul 2017 23:54:10 -0700
    Subject: [PATCH 165/200] Add timeout to initialization of partition in KafkaIO
    
    ---
     .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 81 ++++++++++++++-----
     .../apache/beam/sdk/io/kafka/KafkaIOTest.java | 30 +++++++
     2 files changed, 92 insertions(+), 19 deletions(-)
    
    diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    index e520367f057bf..026313ab2a8cb 100644
    --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
    @@ -49,9 +49,11 @@
     import java.util.Set;
     import java.util.concurrent.ExecutorService;
     import java.util.concurrent.Executors;
    +import java.util.concurrent.Future;
     import java.util.concurrent.ScheduledExecutorService;
     import java.util.concurrent.SynchronousQueue;
     import java.util.concurrent.TimeUnit;
    +import java.util.concurrent.TimeoutException;
     import java.util.concurrent.atomic.AtomicBoolean;
     import javax.annotation.Nullable;
     import org.apache.beam.sdk.annotations.Experimental;
    @@ -1061,8 +1063,32 @@ private void nextBatch() {
           curBatch = Iterators.cycle(nonEmpty);
         }
     
    +    private void setupInitialOffset(PartitionState pState) {
    +      Read spec = source.spec;
    +
    +      if (pState.nextOffset != UNINITIALIZED_OFFSET) {
    +        consumer.seek(pState.topicPartition, pState.nextOffset);
    +      } else {
+        // nextOffset is uninitialized here, meaning we start reading from the latest record
+        // as of now ('latest' is the default, and is configurable) or look up the offset by startReadTime.
    +        // Remember the current position without waiting until the first record is read. This
    +        // ensures checkpoint is accurate even if the reader is closed before reading any records.
    +        Instant startReadTime = spec.getStartReadTime();
    +        if (startReadTime != null) {
    +          pState.nextOffset =
    +              consumerSpEL.offsetForTime(consumer, pState.topicPartition, spec.getStartReadTime());
    +          consumer.seek(pState.topicPartition, pState.nextOffset);
    +        } else {
    +          pState.nextOffset = consumer.position(pState.topicPartition);
    +        }
    +      }
    +    }
    +
         @Override
         public boolean start() throws IOException {
    +      final int defaultPartitionInitTimeout = 60 * 1000;
    +      final int kafkaRequestTimeoutMultiple = 2;
    +
           Read spec = source.spec;
           consumer = spec.getConsumerFactoryFn().apply(spec.getConsumerConfig());
           consumerSpEL.evaluateAssign(consumer, spec.getTopicPartitions());
    @@ -1077,25 +1103,38 @@ public boolean start() throws IOException {
           keyDeserializerInstance.configure(spec.getConsumerConfig(), true);
           valueDeserializerInstance.configure(spec.getConsumerConfig(), false);
     
    -      for (PartitionState p : partitionStates) {
    -        if (p.nextOffset != UNINITIALIZED_OFFSET) {
    -          consumer.seek(p.topicPartition, p.nextOffset);
    -        } else {
    -          // nextOffset is unininitialized here, meaning start reading from latest record as of now
    -          // ('latest' is the default, and is configurable) or 'look up offset by startReadTime.
    -          // Remember the current position without waiting until the first record is read. This
    -          // ensures checkpoint is accurate even if the reader is closed before reading any records.
    -          Instant startReadTime = spec.getStartReadTime();
    -          if (startReadTime != null) {
    -            p.nextOffset =
    -                consumerSpEL.offsetForTime(consumer, p.topicPartition, spec.getStartReadTime());
    -            consumer.seek(p.topicPartition, p.nextOffset);
    -          } else {
    -            p.nextOffset = consumer.position(p.topicPartition);
    +      // Seek to start offset for each partition. This is the first interaction with the server.
    +      // Unfortunately it can block forever in case of network issues like incorrect ACLs.
+      // Initialize the partition in a separate thread and cancel it if it takes longer than a minute.
    +      for (final PartitionState pState : partitionStates) {
    +        Future future =  consumerPollThread.submit(new Runnable() {
    +          public void run() {
    +            setupInitialOffset(pState);
               }
    -        }
    +        });
     
    -        LOG.info("{}: reading from {} starting at offset {}", name, p.topicPartition, p.nextOffset);
    +        try {
    +          // Timeout : 1 minute OR 2 * Kafka consumer request timeout if it is set.
    +          Integer reqTimeout = (Integer) source.spec.getConsumerConfig().get(
    +              ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG);
    +          future.get(reqTimeout != null ? kafkaRequestTimeoutMultiple * reqTimeout
    +                         : defaultPartitionInitTimeout,
    +                     TimeUnit.MILLISECONDS);
    +        } catch (TimeoutException e) {
    +          consumer.wakeup(); // This unblocks consumer stuck on network I/O.
    +          // Likely reason : Kafka servers are configured to advertise internal ips, but
    +          // those ips are not accessible from workers outside.
    +          String msg = String.format(
    +              "%s: Timeout while initializing partition '%s'. "
    +                  + "Kafka client may not be able to connect to servers.",
    +              this, pState.topicPartition);
    +          LOG.error("{}", msg);
    +          throw new IOException(msg);
    +        } catch (Exception e) {
    +          throw new IOException(e);
    +        }
    +        LOG.info("{}: reading from {} starting at offset {}",
    +                 name, pState.topicPartition, pState.nextOffset);
           }
     
           // Start consumer read loop.
    @@ -1329,8 +1368,12 @@ public void close() throws IOException {
           // might block to enqueue right after availableRecordsQueue.poll() below.
           while (!isShutdown) {
     
    -        consumer.wakeup();
    -        offsetConsumer.wakeup();
    +        if (consumer != null) {
    +          consumer.wakeup();
    +        }
    +        if (offsetConsumer != null) {
    +          offsetConsumer.wakeup();
    +        }
             availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread.
             try {
               isShutdown = consumerPollThread.awaitTermination(10, TimeUnit.SECONDS)
    diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
    index b69bc83561fcc..482f5a276f8ae 100644
    --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
    +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
    @@ -83,6 +83,7 @@
     import org.apache.beam.sdk.values.PCollection;
     import org.apache.beam.sdk.values.PCollectionList;
     import org.apache.kafka.clients.consumer.Consumer;
    +import org.apache.kafka.clients.consumer.ConsumerConfig;
     import org.apache.kafka.clients.consumer.ConsumerRecord;
     import org.apache.kafka.clients.consumer.MockConsumer;
     import org.apache.kafka.clients.consumer.OffsetResetStrategy;
    @@ -363,6 +364,35 @@ public void testUnboundedSource() {
         p.run();
       }
     
    +  @Test
    +  public void testUnreachableKafkaBrokers() {
    +    // Expect an exception when the Kafka brokers are not reachable on the workers.
    +    // We specify partitions explicitly so that splitting does not involve server interaction.
    +    // Set request timeout to 10ms so that test does not take long.
    +
    +    thrown.expect(Exception.class);
    +    thrown.expectMessage("Reader-0: Timeout while initializing partition 'test-0'");
    +
    +    int numElements = 1000;
    +    PCollection input = p
    +        .apply(KafkaIO.read()
    +            .withBootstrapServers("8.8.8.8:9092") // Google public DNS ip.
    +            .withTopicPartitions(ImmutableList.of(new TopicPartition("test", 0)))
    +            .withKeyDeserializer(IntegerDeserializer.class)
    +            .withValueDeserializer(LongDeserializer.class)
    +            .updateConsumerProperties(ImmutableMap.of(
    +                ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 10,
    +                ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 5,
    +                ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 8,
    +                ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG, 8))
    +            .withMaxNumRecords(10)
    +            .withoutMetadata())
    +        .apply(Values.create());
    +
    +    addCountingAsserts(input, numElements);
    +    p.run();
    +  }
    +
       @Test
       public void testUnboundedSourceWithSingleTopic() {
         // same as testUnboundedSource, but with single topic
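
A minimal sketch of the pattern the patch applies: run the potentially blocking first
server interaction on a worker thread, bound the wait with either a default or twice the
configured Kafka request timeout, and unblock the stuck client when the wait expires.
The function and callback names are illustrative; consumer.wakeup() is the real Kafka
client call used in the diff.

    from concurrent.futures import ThreadPoolExecutor, TimeoutError

    DEFAULT_PARTITION_INIT_TIMEOUT_SECS = 60
    REQUEST_TIMEOUT_MULTIPLE = 2

    def init_partition_with_timeout(setup_fn, wakeup_fn, request_timeout_secs=None):
        """Run setup_fn, calling wakeup_fn and failing if it takes too long."""
        timeout = (REQUEST_TIMEOUT_MULTIPLE * request_timeout_secs
                   if request_timeout_secs else DEFAULT_PARTITION_INIT_TIMEOUT_SECS)
        pool = ThreadPoolExecutor(max_workers=1)
        try:
            future = pool.submit(setup_fn)
            try:
                future.result(timeout=timeout)
            except TimeoutError:
                wakeup_fn()  # e.g. consumer.wakeup(), which unblocks network I/O
                raise IOError('Timeout while initializing partition; the Kafka '
                              'client may not be able to reach the brokers.')
        finally:
            pool.shutdown(wait=False)
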
    
    From bd631b89a8434f0756e1596875e89013fb623ab5 Mon Sep 17 00:00:00 2001
    From: Kenneth Knowles 
    Date: Thu, 22 Jun 2017 18:09:11 -0700
    Subject: [PATCH 166/200] Ignore processing time timers in expired windows
    
    ---
     .../beam/runners/core/ReduceFnRunner.java     | 10 ++++++
     .../beam/runners/core/ReduceFnRunnerTest.java | 32 +++++++++++++++++++
     2 files changed, 42 insertions(+)
    
    diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    index ef33befffc939..0632c052912be 100644
    --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    @@ -693,6 +693,11 @@ public void onTimers(Iterable timers) throws Exception {
           @SuppressWarnings("unchecked")
             WindowNamespace windowNamespace = (WindowNamespace) timer.getNamespace();
           W window = windowNamespace.getWindow();
    +
    +      if (TimeDomain.PROCESSING_TIME == timer.getDomain() && windowIsExpired(window)) {
    +        continue;
    +      }
    +
           ReduceFn.Context directContext =
               contextFactory.base(window, StateStyle.DIRECT);
           ReduceFn.Context renamedContext =
    @@ -1090,4 +1095,9 @@ private void cancelEndOfWindowAndGarbageCollectionTimers(
         }
       }
     
    +  private boolean windowIsExpired(BoundedWindow w) {
    +    return timerInternals
    +        .currentInputWatermarkTime()
    +        .isAfter(w.maxTimestamp().plus(windowingStrategy.getAllowedLateness()));
    +  }
     }
    diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    index 3a2c2205c5ea6..79ee91b38495f 100644
    --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    @@ -284,6 +284,38 @@ public void testOnElementCombiningDiscarding() throws Exception {
         tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow);
       }
     
    +  /**
    +   * Tests that when a processing time timer comes in after a window is expired
    +   * it is just ignored.
    +   */
    +  @Test
    +  public void testLateProcessingTimeTimer() throws Exception {
    +    WindowingStrategy strategy =
    +        WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100)))
    +            .withTimestampCombiner(TimestampCombiner.EARLIEST)
    +            .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES)
    +            .withAllowedLateness(Duration.ZERO)
    +            .withTrigger(
    +                Repeatedly.forever(
    +                    AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.millis(10))));
    +
    +    ReduceFnTester tester =
    +        ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of());
    +
    +    tester.advanceProcessingTime(new Instant(5000));
    +    injectElement(tester, 2); // processing timer @ 5000 + 10; EOW timer @ 100
    +    injectElement(tester, 5);
    +
    +    // After this advancement, the window is expired and only the GC process
    +    // should be allowed to touch it
    +    tester.advanceInputWatermarkNoTimers(new Instant(100));
    +
    +    // This should not output
    +    tester.advanceProcessingTime(new Instant(6000));
    +
    +    assertThat(tester.extractOutput(), emptyIterable());
    +  }
    +
       /**
        * Tests that when a processing time timer comes in after a window is expired
        * but in the same bundle it does not cause a spurious output.
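
The guard added to ReduceFnRunner reduces to a watermark comparison: a window is expired
once the input watermark passes the window's maximum timestamp plus the allowed lateness,
and processing-time timers for such windows are skipped (event-time timers still drive
garbage collection). A hedged sketch of that predicate with timestamps as plain numbers
and illustrative names:

    def window_is_expired(input_watermark, window_max_timestamp, allowed_lateness):
        """True once the input watermark has passed max timestamp + allowed lateness."""
        return input_watermark > window_max_timestamp + allowed_lateness

    def should_skip_timer(timer_domain, input_watermark, window_max_timestamp,
                          allowed_lateness):
        # Only processing-time timers in already-expired windows are dropped.
        return (timer_domain == 'PROCESSING_TIME'
                and window_is_expired(input_watermark, window_max_timestamp,
                                      allowed_lateness))
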
    
    From 935c077341de580dddd4b29ffee3926795acf403 Mon Sep 17 00:00:00 2001
    From: Kenneth Knowles 
    Date: Thu, 22 Jun 2017 18:43:39 -0700
    Subject: [PATCH 167/200] Process timer firings for a window together
    
    ---
     .../complete/game/LeaderBoardTest.java        |  2 +
     .../beam/runners/core/ReduceFnRunner.java     | 98 +++++++++++++------
     .../beam/runners/core/ReduceFnRunnerTest.java | 49 +++++++++-
     3 files changed, 115 insertions(+), 34 deletions(-)
    
    diff --git a/examples/java8/src/test/java/org/apache/beam/examples/complete/game/LeaderBoardTest.java b/examples/java8/src/test/java/org/apache/beam/examples/complete/game/LeaderBoardTest.java
    index 745c210c56357..611e2b3184936 100644
    --- a/examples/java8/src/test/java/org/apache/beam/examples/complete/game/LeaderBoardTest.java
    +++ b/examples/java8/src/test/java/org/apache/beam/examples/complete/game/LeaderBoardTest.java
    @@ -276,6 +276,8 @@ public void testTeamScoresDroppablyLate() {
             .addElements(event(TestUser.RED_ONE, 4, Duration.standardMinutes(2)),
                 event(TestUser.BLUE_TWO, 3, Duration.ZERO),
                 event(TestUser.BLUE_ONE, 3, Duration.standardMinutes(3)))
    +        // Move the watermark to the end of the window to output on time
    +        .advanceWatermarkTo(baseTime.plus(TEAM_WINDOW_DURATION))
             // Move the watermark past the end of the allowed lateness plus the end of the window
             .advanceWatermarkTo(baseTime.plus(ALLOWED_LATENESS)
                 .plus(TEAM_WINDOW_DURATION).plus(Duration.standardMinutes(1)))
    diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    index 0632c052912be..634a2d13e7990 100644
    --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/ReduceFnRunner.java
    @@ -29,7 +29,6 @@
     import java.util.Collections;
     import java.util.HashMap;
     import java.util.HashSet;
    -import java.util.LinkedList;
     import java.util.List;
     import java.util.Map;
     import java.util.Set;
    @@ -638,11 +637,9 @@ private void processElement(Map windowToMergeResult, WindowedValue
       }
     
       /**
    -   * Enriches TimerData with state necessary for processing a timer as well as
    -   * common queries about a timer.
    +   * A descriptor of the activation for a window based on a timer.
        */
    -  private class EnrichedTimerData {
    -    public final Instant timestamp;
    +  private class WindowActivation {
         public final ReduceFn.Context directContext;
         public final ReduceFn.Context renamedContext;
         // If this is an end-of-window timer then we may need to set a garbage collection timer
    @@ -653,19 +650,34 @@ private class EnrichedTimerData {
         // end-of-window time to be a signal to garbage collect.
         public final boolean isGarbageCollection;
     
    -    EnrichedTimerData(
    -        TimerData timer,
    +    WindowActivation(
             ReduceFn.Context directContext,
             ReduceFn.Context renamedContext) {
    -      this.timestamp = timer.getTimestamp();
           this.directContext = directContext;
           this.renamedContext = renamedContext;
           W window = directContext.window();
    -      this.isEndOfWindow = TimeDomain.EVENT_TIME == timer.getDomain()
    -          && timer.getTimestamp().equals(window.maxTimestamp());
    -      Instant cleanupTime = LateDataUtils.garbageCollectionTime(window, windowingStrategy);
    +
    +      // The output watermark is before the end of the window if it is either unknown
    +      // or it is known to be before it. If it is unknown, that means that there hasn't been
    +      // enough data to advance it.
    +      boolean outputWatermarkBeforeEOW =
    +              timerInternals.currentOutputWatermarkTime() == null
    +          || !timerInternals.currentOutputWatermarkTime().isAfter(window.maxTimestamp());
    +
    +      // The "end of the window" is reached when the local input watermark (for this key) surpasses
    +      // it but the local output watermark (also for this key) has not. After data is emitted and
    +      // the output watermark hold is released, the output watermark on this key will immediately
    +      // exceed the end of the window (otherwise we could see multiple ON_TIME outputs)
    +      this.isEndOfWindow =
    +          timerInternals.currentInputWatermarkTime().isAfter(window.maxTimestamp())
    +              && outputWatermarkBeforeEOW;
    +
    +      // The "GC time" is reached when the input watermark surpasses the end of the window
    +      // plus allowed lateness. After this, the window is expired and expunged.
           this.isGarbageCollection =
    -          TimeDomain.EVENT_TIME == timer.getDomain() && !timer.getTimestamp().isBefore(cleanupTime);
    +          timerInternals
    +              .currentInputWatermarkTime()
    +              .isAfter(LateDataUtils.garbageCollectionTime(window, windowingStrategy));
         }
     
         // Has this window had its trigger finish?
    @@ -684,9 +696,10 @@ public void onTimers(Iterable timers) throws Exception {
           return;
         }
     
    -    // Create a reusable context for each timer and begin prefetching necessary
    +    // Create a reusable context for each window and begin prefetching necessary
         // state.
    -    List enrichedTimers = new LinkedList();
    +    Map windowActivations = new HashMap();
    +
         for (TimerData timer : timers) {
           checkArgument(timer.getNamespace() instanceof WindowNamespace,
               "Expected timer to be in WindowNamespace, but was in %s", timer.getNamespace());
    @@ -694,7 +707,24 @@ public void onTimers(Iterable timers) throws Exception {
             WindowNamespace windowNamespace = (WindowNamespace) timer.getNamespace();
           W window = windowNamespace.getWindow();
     
    -      if (TimeDomain.PROCESSING_TIME == timer.getDomain() && windowIsExpired(window)) {
    +      WindowTracing.debug("{}: Received timer key:{}; window:{}; data:{} with "
    +              + "inputWatermark:{}; outputWatermark:{}",
    +          ReduceFnRunner.class.getSimpleName(),
    +          key, window, timer,
    +          timerInternals.currentInputWatermarkTime(),
    +          timerInternals.currentOutputWatermarkTime());
    +
    +      // Processing time timers for an expired window are ignored, just like elements
+      // that show up too late. Window GC is managed by an event time timer.
    +      if (TimeDomain.EVENT_TIME != timer.getDomain() && windowIsExpired(window)) {
    +        continue;
    +      }
    +
    +      // How a window is processed is a function only of the current state, not the details
    +      // of the timer. This makes us robust to large leaps in processing time and watermark
    +      // time, where both EOW and GC timers come in together and we need to GC and emit
    +      // the final pane.
    +      if (windowActivations.containsKey(window)) {
             continue;
           }
     
    @@ -702,11 +732,11 @@ public void onTimers(Iterable timers) throws Exception {
               contextFactory.base(window, StateStyle.DIRECT);
           ReduceFn.Context renamedContext =
               contextFactory.base(window, StateStyle.RENAMED);
    -      EnrichedTimerData enrichedTimer = new EnrichedTimerData(timer, directContext, renamedContext);
    -      enrichedTimers.add(enrichedTimer);
    +      WindowActivation windowActivation = new WindowActivation(directContext, renamedContext);
    +      windowActivations.put(window, windowActivation);
     
           // Perform prefetching of state to determine if the trigger should fire.
    -      if (enrichedTimer.isGarbageCollection) {
    +      if (windowActivation.isGarbageCollection) {
             triggerRunner.prefetchIsClosed(directContext.state());
           } else {
             triggerRunner.prefetchShouldFire(directContext.window(), directContext.state());
    @@ -714,7 +744,7 @@ public void onTimers(Iterable timers) throws Exception {
         }
     
         // For those windows that are active and open, prefetch the triggering or emitting state.
    -    for (EnrichedTimerData timer : enrichedTimers) {
    +    for (WindowActivation timer : windowActivations.values()) {
           if (timer.windowIsActiveAndOpen()) {
             ReduceFn.Context directContext = timer.directContext;
             if (timer.isGarbageCollection) {
    @@ -727,25 +757,27 @@ public void onTimers(Iterable timers) throws Exception {
         }
     
         // Perform processing now that everything is prefetched.
    -    for (EnrichedTimerData timer : enrichedTimers) {
    -      ReduceFn.Context directContext = timer.directContext;
    -      ReduceFn.Context renamedContext = timer.renamedContext;
    +    for (WindowActivation windowActivation : windowActivations.values()) {
    +      ReduceFn.Context directContext = windowActivation.directContext;
    +      ReduceFn.Context renamedContext = windowActivation.renamedContext;
     
    -      if (timer.isGarbageCollection) {
    -        WindowTracing.debug("ReduceFnRunner.onTimer: Cleaning up for key:{}; window:{} at {} with "
    -                + "inputWatermark:{}; outputWatermark:{}",
    -            key, directContext.window(), timer.timestamp,
    +      if (windowActivation.isGarbageCollection) {
    +        WindowTracing.debug(
    +            "{}: Cleaning up for key:{}; window:{} with inputWatermark:{}; outputWatermark:{}",
    +            ReduceFnRunner.class.getSimpleName(),
    +            key,
    +            directContext.window(),
                 timerInternals.currentInputWatermarkTime(),
                 timerInternals.currentOutputWatermarkTime());
     
    -        boolean windowIsActiveAndOpen = timer.windowIsActiveAndOpen();
    +        boolean windowIsActiveAndOpen = windowActivation.windowIsActiveAndOpen();
             if (windowIsActiveAndOpen) {
               // We need to call onTrigger to emit the final pane if required.
               // The final pane *may* be ON_TIME if no prior ON_TIME pane has been emitted,
               // and the watermark has passed the end of the window.
               @Nullable
               Instant newHold = onTrigger(
    -              directContext, renamedContext, true /* isFinished */, timer.isEndOfWindow);
    +              directContext, renamedContext, true /* isFinished */, windowActivation.isEndOfWindow);
               checkState(newHold == null, "Hold placed at %s despite isFinished being true.", newHold);
             }
     
    @@ -753,18 +785,20 @@ public void onTimers(Iterable timers) throws Exception {
             // see elements for it again.
             clearAllState(directContext, renamedContext, windowIsActiveAndOpen);
           } else {
    -        WindowTracing.debug("ReduceFnRunner.onTimer: Triggering for key:{}; window:{} at {} with "
    +        WindowTracing.debug(
+            "{}.onTimers: Triggering for key:{}; window:{} with "
                     + "inputWatermark:{}; outputWatermark:{}",
+            ReduceFnRunner.class.getSimpleName(),
+            key,
    +            key,
    +            directContext.window(),
                 timerInternals.currentInputWatermarkTime(),
                 timerInternals.currentOutputWatermarkTime());
    -        if (timer.windowIsActiveAndOpen()
    +        if (windowActivation.windowIsActiveAndOpen()
                 && triggerRunner.shouldFire(
                        directContext.window(), directContext.timers(), directContext.state())) {
               emit(directContext, renamedContext);
             }
     
    -        if (timer.isEndOfWindow) {
    +        if (windowActivation.isEndOfWindow) {
               // If the window strategy trigger includes a watermark trigger then at this point
               // there should be no data holds, either because we'd already cleared them on an
               // earlier onTrigger, or because we just cleared them on the above emit.
    diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    index 79ee91b38495f..4f13af19b40e9 100644
    --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
    @@ -55,6 +55,7 @@
     import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
     import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
     import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
    +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
     import org.apache.beam.sdk.transforms.windowing.FixedWindows;
     import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
     import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
    @@ -79,7 +80,6 @@
     import org.joda.time.Duration;
     import org.joda.time.Instant;
     import org.junit.Before;
    -import org.junit.Ignore;
     import org.junit.Test;
     import org.junit.runner.RunWith;
     import org.junit.runners.JUnit4;
    @@ -246,6 +246,52 @@ public void testOnElementBufferingAccumulating() throws Exception {
         tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow);
       }
     
    +  /**
    +   * Tests that with the default trigger we will not produce two ON_TIME panes, even
    +   * if there are two outputs that are both candidates.
    +   */
    +  @Test
    +  public void testOnlyOneOnTimePane() throws Exception {
    +    WindowingStrategy strategy =
    +        WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(10)))
    +            .withTrigger(DefaultTrigger.of())
    +            .withMode(AccumulationMode.ACCUMULATING_FIRED_PANES)
    +            .withAllowedLateness(Duration.millis(100));
    +
    +    ReduceFnTester tester =
    +        ReduceFnTester.combining(strategy, Sum.ofIntegers(), VarIntCoder.of());
    +
    +    tester.advanceInputWatermark(new Instant(0));
    +
    +    int value1 = 1;
    +    int value2 = 3;
    +
    +    // A single element that should be in the ON_TIME output
    +    tester.injectElements(
    +        TimestampedValue.of(value1, new Instant(1)));
    +
    +    // Should fire ON_TIME
    +    tester.advanceInputWatermark(new Instant(10));
    +
    +    // The DefaultTrigger should cause output labeled LATE, even though it does not have to be
    +    // labeled as such.
    +    tester.injectElements(
    +        TimestampedValue.of(value2, new Instant(3)));
    +
    +    List> output = tester.extractOutput();
    +    assertEquals(2, output.size());
    +
    +    assertThat(output.get(0), WindowMatchers.isWindowedValue(equalTo(value1)));
    +    assertThat(output.get(1), WindowMatchers.isWindowedValue(equalTo(value1 + value2)));
    +
    +    assertThat(
    +        output.get(0),
    +        WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(true, false, Timing.ON_TIME, 0, 0)));
    +    assertThat(
    +        output.get(1),
    +        WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, false, Timing.LATE, 1, 1)));
    +  }
    +
       @Test
       public void testOnElementCombiningDiscarding() throws Exception {
         // Test basic execution of a trigger using a non-combining window set and discarding mode.
    @@ -458,7 +504,6 @@ public void testCombiningAccumulatingProcessingTimeSeparateBundles() throws Exce
        * marked as final.
        */
       @Test
    -  @Ignore("https://issues.apache.org/jira/browse/BEAM-2505")
       public void testCombiningAccumulatingEventTime() throws Exception {
         WindowingStrategy strategy =
             WindowingStrategy.of((WindowFn) FixedWindows.of(Duration.millis(100)))
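The WindowActivation logic added by this patch reduces to two watermark comparisons, and the single-ON_TIME-pane guarantee checked by the new testOnlyOneOnTimePane follows from them. A standalone sketch of how the two booleans evaluate at the moment the watermark reaches the end of that test's window (FixedWindows of 10ms, 100ms allowed lateness), with a null output watermark standing in for "not yet known"; the class name is illustrative:

import org.joda.time.Duration;
import org.joda.time.Instant;

public class WindowActivationSketch {
  public static void main(String[] args) {
    // Window [0, 10) from FixedWindows.of(Duration.millis(10)): its max timestamp is 9.
    Instant windowMaxTimestamp = new Instant(9);
    Instant gcTime = windowMaxTimestamp.plus(Duration.millis(100));   // plus allowed lateness

    Instant inputWatermark = new Instant(10);   // after tester.advanceInputWatermark(new Instant(10))
    Instant outputWatermark = null;             // unknown: nothing emitted for this key yet

    boolean outputWatermarkBeforeEOW =
        outputWatermark == null || !outputWatermark.isAfter(windowMaxTimestamp);
    boolean isEndOfWindow =
        inputWatermark.isAfter(windowMaxTimestamp) && outputWatermarkBeforeEOW;
    boolean isGarbageCollection = inputWatermark.isAfter(gcTime);

    // Prints isEndOfWindow=true isGarbageCollection=false: emit the single ON_TIME pane now;
    // a later firing for the same window can only be LATE (or the final GC pane).
    System.out.println("isEndOfWindow=" + isEndOfWindow
        + " isGarbageCollection=" + isGarbageCollection);
  }
}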
    
    From da92256ba64c5f4777776bd6283db2484bd72293 Mon Sep 17 00:00:00 2001
    From: Jeremie Lenfant-Engelmann 
    Date: Wed, 28 Jun 2017 16:11:21 -0700
    Subject: [PATCH 168/200] Made DataflowRunner TransformTranslator public
    
    ---
     .../org/apache/beam/runners/dataflow/TransformTranslator.java  | 3 ++-
     1 file changed, 2 insertions(+), 1 deletion(-)
    
    diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    index a7452b2fdfa89..7f61b6cf3138d 100644
    --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    @@ -36,7 +36,8 @@
      * A {@link TransformTranslator} knows how to translate a particular subclass of {@link PTransform}
      * for the Cloud Dataflow service. It does so by mutating the {@link TranslationContext}.
      */
    -interface TransformTranslator {
    +@Internal
    +public interface TransformTranslator {
       void translate(TransformT transform, TranslationContext context);
     
       /**
    
    From 17bc3b140c7c7315880ce18d4e15d6ac512c35d2 Mon Sep 17 00:00:00 2001
    From: Kenneth Knowles 
    Date: Thu, 6 Jul 2017 21:45:39 -0700
    Subject: [PATCH 169/200] Fix bad merge
    
    ---
     .../org/apache/beam/runners/dataflow/TransformTranslator.java    | 1 +
     1 file changed, 1 insertion(+)
    
    diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    index 7f61b6cf3138d..06ed1e07b1823 100644
    --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
    @@ -22,6 +22,7 @@
     import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
     import org.apache.beam.runners.dataflow.util.OutputReference;
     import org.apache.beam.sdk.Pipeline;
    +import org.apache.beam.sdk.annotations.Internal;
     import org.apache.beam.sdk.coders.Coder;
     import org.apache.beam.sdk.runners.AppliedPTransform;
     import org.apache.beam.sdk.transforms.PTransform;
    
    From c8d983363efd3f3d93825ecc8e8abae2dfa4e008 Mon Sep 17 00:00:00 2001
    From: Innocent Djiofack 
    Date: Wed, 28 Jun 2017 22:15:11 -0400
    Subject: [PATCH 170/200] Simplified ByteBuddyOnTimerInvokerFactory
    
    ---
     .../ByteBuddyOnTimerInvokerFactory.java       | 73 +++++++------------
     .../reflect/OnTimerMethodSpecifier.java       | 37 ++++++++++
     2 files changed, 65 insertions(+), 45 deletions(-)
     create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/OnTimerMethodSpecifier.java
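The change below collapses a cache-of-caches (keyed first by DoFn class, then by timer id) into a single cache keyed by a composite OnTimerMethodSpecifier value. A self-contained sketch of that pattern with Guava, substituting a plain value class for AutoValue and a cheap string for the expensive invoker-class generation; all names here are illustrative:

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

public class CompositeKeyCacheSketch {
  /** Plays the role of OnTimerMethodSpecifier: one value object instead of two nested caches. */
  static final class Key {
    final Class<?> fnClass;
    final String timerId;

    Key(Class<?> fnClass, String timerId) {
      this.fnClass = fnClass;
      this.timerId = timerId;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Key)) {
        return false;
      }
      Key other = (Key) o;
      return fnClass.equals(other.fnClass) && timerId.equals(other.timerId);
    }

    @Override
    public int hashCode() {
      return 31 * fnClass.hashCode() + timerId.hashCode();
    }
  }

  private static final LoadingCache<Key, String> CACHE =
      CacheBuilder.newBuilder()
          .build(
              new CacheLoader<Key, String>() {
                @Override
                public String load(Key key) {
                  // Stand-in for the expensive class generation done by the real factory.
                  return key.fnClass.getSimpleName() + "#" + key.timerId;
                }
              });

  public static void main(String[] args) throws Exception {
    // Both lookups for the same (class, timerId) pair hit the same cached entry.
    System.out.println(CACHE.get(new Key(String.class, "my-timer")));
    System.out.println(CACHE.get(new Key(String.class, "my-timer")));
  }
}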
    
    diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyOnTimerInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyOnTimerInvokerFactory.java
    index e031337279787..5e31f2e3e8463 100644
    --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyOnTimerInvokerFactory.java
    +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyOnTimerInvokerFactory.java
    @@ -17,6 +17,7 @@
      */
     package org.apache.beam.sdk.transforms.reflect;
     
    +
     import com.google.common.base.CharMatcher;
     import com.google.common.cache.CacheBuilder;
     import com.google.common.cache.CacheLoader;
    @@ -61,13 +62,14 @@ public  OnTimerInvoker forTimer(
     
         @SuppressWarnings("unchecked")
         Class> fnClass = (Class>) fn.getClass();
    -
         try {
    -      Constructor constructor = constructorCache.get(fnClass).get(timerId);
    -      @SuppressWarnings("unchecked")
    -      OnTimerInvoker invoker =
    +        OnTimerMethodSpecifier onTimerMethodSpecifier =
    +                OnTimerMethodSpecifier.forClassAndTimerId(fnClass, timerId);
    +        Constructor constructor = constructorCache.get(onTimerMethodSpecifier);
    +
    +        OnTimerInvoker invoker =
               (OnTimerInvoker) constructor.newInstance(fn);
    -      return invoker;
    +        return invoker;
         } catch (InstantiationException
             | IllegalAccessException
             | IllegalArgumentException
    @@ -97,50 +99,31 @@ private ByteBuddyOnTimerInvokerFactory() {}
       private static final String FN_DELEGATE_FIELD_NAME = "delegate";
     
       /**
    -   * A cache of constructors of generated {@link OnTimerInvoker} classes, keyed by {@link DoFn}
    -   * class and then by {@link TimerId}.
    +   * A cache of constructors of generated {@link OnTimerInvoker} classes,
    +   * keyed by {@link OnTimerMethodSpecifier}.
        *
        * 

    Needed because generating an invoker class is expensive, and to avoid generating an * excessive number of classes consuming PermGen memory in Java's that still have PermGen. */ - private final LoadingCache>, LoadingCache>> - constructorCache = - CacheBuilder.newBuilder() - .build( - new CacheLoader< - Class>, LoadingCache>>() { - @Override - public LoadingCache> load( - final Class> fnClass) throws Exception { - return CacheBuilder.newBuilder().build(new OnTimerConstructorLoader(fnClass)); - } - }); - - /** - * A cache loader fixed to a particular {@link DoFn} class that loads constructors for the - * invokers for its {@link OnTimer @OnTimer} methods. - */ - private static class OnTimerConstructorLoader extends CacheLoader> { - - private final DoFnSignature signature; - - public OnTimerConstructorLoader(Class> clazz) { - this.signature = DoFnSignatures.getSignature(clazz); - } - - @Override - public Constructor load(String timerId) throws Exception { - Class> invokerClass = - generateOnTimerInvokerClass(signature, timerId); - try { - return invokerClass.getConstructor(signature.fnClass()); - } catch (IllegalArgumentException | NoSuchMethodException | SecurityException e) { - throw new RuntimeException(e); - } - } - } - - /** + private final LoadingCache> constructorCache = + CacheBuilder.newBuilder().build( + new CacheLoader>() { + @Override + public Constructor load(final OnTimerMethodSpecifier onTimerMethodSpecifier) + throws Exception { + DoFnSignature signature = + DoFnSignatures.getSignature(onTimerMethodSpecifier.fnClass()); + Class> invokerClass = + generateOnTimerInvokerClass(signature, onTimerMethodSpecifier.timerId()); + try { + return invokerClass.getConstructor(signature.fnClass()); + } catch (IllegalArgumentException | NoSuchMethodException | SecurityException e) { + throw new RuntimeException(e); + } + + } + }); + /** * Generates a {@link OnTimerInvoker} class for the given {@link DoFnSignature} and {@link * TimerId}. */ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/OnTimerMethodSpecifier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/OnTimerMethodSpecifier.java new file mode 100644 index 0000000000000..edf7e3ccb0223 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/OnTimerMethodSpecifier.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms.reflect; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.transforms.DoFn; + +/** + * Used by {@link ByteBuddyOnTimerInvokerFactory} to Dynamically generate + * {@link OnTimerInvoker} instances for invoking a particular + * {@link DoFn.TimerId} on a particular {@link DoFn}. 
+ */ + +@AutoValue +abstract class OnTimerMethodSpecifier { + public abstract Class> fnClass(); + public abstract String timerId(); + public static OnTimerMethodSpecifier + forClassAndTimerId(Class> fnClass, String timerId){ + return new AutoValue_OnTimerMethodSpecifier(fnClass, timerId); + } +} From 35061e88066589d1dbfa81aa37fbb270274d70c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Thu, 6 Jul 2017 17:37:33 +0200 Subject: [PATCH 171/200] Fix javadoc generation for AmqpIO, CassandraIO and HCatalogIO --- pom.xml | 18 ++++++++++++++++++ sdks/java/javadoc/pom.xml | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/pom.xml b/pom.xml index 01474c1a001c7..d9ab9ae9a1f7e 100644 --- a/pom.xml +++ b/pom.xml @@ -426,6 +426,18 @@ ${project.version} + + org.apache.beam + beam-sdks-java-io-amqp + ${project.version} + + + + org.apache.beam + beam-sdks-java-io-cassandra + ${project.version} + + org.apache.beam beam-sdks-java-io-elasticsearch @@ -463,6 +475,12 @@ ${project.version} + + org.apache.beam + beam-sdks-java-io-hcatalog + ${project.version} + + org.apache.beam beam-sdks-java-io-jdbc diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml index ddb92cfb8e2b6..51109fbe1c943 100644 --- a/sdks/java/javadoc/pom.xml +++ b/sdks/java/javadoc/pom.xml @@ -97,6 +97,16 @@ beam-sdks-java-harness + + org.apache.beam + beam-sdks-java-io-amqp + + + + org.apache.beam + beam-sdks-java-io-cassandra + + org.apache.beam beam-sdks-java-io-elasticsearch @@ -122,6 +132,11 @@ beam-sdks-java-io-hbase + + org.apache.beam + beam-sdks-java-io-hcatalog + + org.apache.beam beam-sdks-java-io-jdbc From 3621b9555e6c446737e98a045c2f8a20b1e9c3ad Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 7 Jul 2017 08:49:08 -0700 Subject: [PATCH 172/200] Move DirectRunner knob for suppressing runner-determined sharding out of core SDK --- runners/direct-java/pom.xml | 2 +- .../beam/runners/direct/DirectRegistrar.java | 2 +- .../beam/runners/direct/DirectRunner.java | 5 +-- .../runners/direct/DirectTestOptions.java | 42 +++++++++++++++++++ .../runners/direct/DirectRegistrarTest.java | 2 +- .../beam/sdk/testing/TestPipelineOptions.java | 10 ----- 6 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml index 0e1f73a4f3cb5..e14e8136c87a7 100644 --- a/runners/direct-java/pom.xml +++ b/runners/direct-java/pom.xml @@ -156,7 +156,7 @@ [ "--runner=DirectRunner", - "--unitTest" + "--runnerDeterminedSharding=false" ] diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java index 0e6fbab888200..53fb2f24194d8 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java @@ -50,7 +50,7 @@ public static class Options implements PipelineOptionsRegistrar { @Override public Iterable> getPipelineOptions() { return ImmutableList.>of( - DirectOptions.class); + DirectOptions.class, DirectTestOptions.class); } } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index a16e24dc262bf..7a221c4cf5d59 100644 --- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -43,7 +43,6 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; -import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.MultiOutput; @@ -222,9 +221,9 @@ public DirectPipelineResult run(Pipeline pipeline) { @SuppressWarnings("rawtypes") @VisibleForTesting List defaultTransformOverrides() { - TestPipelineOptions testOptions = options.as(TestPipelineOptions.class); + DirectTestOptions testOptions = options.as(DirectTestOptions.class); ImmutableList.Builder builder = ImmutableList.builder(); - if (!testOptions.isUnitTest()) { + if (testOptions.isRunnerDeterminedSharding()) { builder.add( PTransformOverride.of( PTransformMatchers.writeWithRunnerDeterminedSharding(), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java new file mode 100644 index 0000000000000..a4264430613f3 --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.options.ApplicationNameOptions; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.Hidden; +import org.apache.beam.sdk.options.PipelineOptions; + +/** + * Internal-only options for tweaking the behavior of the {@link DirectRunner} in ways that users + * should never do. + * + *

    Currently, the only use is to disable user-friendly overrides that prevent fully testing + * certain composite transforms. + */ +@Internal +@Hidden +public interface DirectTestOptions extends PipelineOptions, ApplicationNameOptions { + @Default.Boolean(true) + @Description( + "Indicates whether this is an automatically-run unit test.") + boolean isRunnerDeterminedSharding(); + void setRunnerDeterminedSharding(boolean goAheadAndDetermineSharding); +} diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java index 603e43e30f6be..4b909bc41fb81 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java @@ -37,7 +37,7 @@ public class DirectRegistrarTest { @Test public void testCorrectOptionsAreReturned() { assertEquals( - ImmutableList.of(DirectOptions.class), + ImmutableList.of(DirectOptions.class, DirectTestOptions.class), new Options().getPipelineOptions()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java index 904f3a2ff837d..206bc1f343c4b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipelineOptions.java @@ -20,10 +20,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import javax.annotation.Nullable; import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; -import org.apache.beam.sdk.options.Hidden; import org.apache.beam.sdk.options.PipelineOptions; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; @@ -52,14 +50,6 @@ public interface TestPipelineOptions extends PipelineOptions { Long getTestTimeoutSeconds(); void setTestTimeoutSeconds(Long value); - @Default.Boolean(false) - @Internal - @Hidden - @org.apache.beam.sdk.options.Description( - "Indicates whether this is an automatically-run unit test.") - boolean isUnitTest(); - void setUnitTest(boolean unitTest); - /** * Factory for {@link PipelineResult} matchers which always pass. 
*/ From 175ff2fe873cacb11c0bb47c9812e1cd336ada5f Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 5 Jul 2017 17:24:25 -0700 Subject: [PATCH 173/200] Reject stateful ParDo if coder not KvCoder with deterministic key coder --- .../org/apache/beam/sdk/transforms/ParDo.java | 27 +++++ .../apache/beam/sdk/transforms/ParDoTest.java | 102 ++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index db1f7918e4fcd..0d03835bb7fb8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.transforms.DoFn.WindowedContext; import org.apache.beam.sdk.transforms.display.DisplayData; @@ -455,6 +456,27 @@ private static void finishSpecifyingStateSpecs( } } + private static void validateStateApplicableForInput( + DoFn fn, + PCollection input) { + Coder inputCoder = input.getCoder(); + checkArgument( + inputCoder instanceof KvCoder, + "%s requires its input to use %s in order to use state and timers.", + ParDo.class.getSimpleName(), + KvCoder.class.getSimpleName()); + + KvCoder kvCoder = (KvCoder) inputCoder; + try { + kvCoder.getKeyCoder().verifyDeterministic(); + } catch (Coder.NonDeterministicException exc) { + throw new IllegalArgumentException( + String.format( + "%s requires a deterministic key coder in order to use state and timers", + ParDo.class.getSimpleName())); + } + } + /** * Try to provide coders for as many of the type arguments of given * {@link DoFnSignature.StateDeclaration} as possible. @@ -737,6 +759,11 @@ public PCollectionTuple expand(PCollection input) { // Use coder registry to determine coders for all StateSpec defined in the fn signature. 
finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), input.getCoder()); + DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + if (signature.usesState() || signature.usesTimers()) { + validateStateApplicableForInput(fn, input); + } + PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), TupleTagList.of(mainOutputTag).and(additionalOutputTags.getAll()), diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 5b60ef3ed03b1..fa4949e9501a3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -1592,6 +1592,108 @@ public void processElement( pipeline.run(); } + @Test + public void testStateNotKeyed() { + final String stateId = "foo"; + + DoFn fn = + new DoFn() { + + @StateId(stateId) + private final StateSpec> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState state) {} + }; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("state"); + thrown.expectMessage("KvCoder"); + + pipeline.apply(Create.of("hello", "goodbye", "hello again")).apply(ParDo.of(fn)); + } + + @Test + public void testStateNotDeterministic() { + final String stateId = "foo"; + + // DoubleCoder is not deterministic, so this should crash + DoFn, Integer> fn = + new DoFn, Integer>() { + + @StateId(stateId) + private final StateSpec> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState state) {} + }; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("state"); + thrown.expectMessage("deterministic"); + + pipeline + .apply(Create.of(KV.of(1.0, "hello"), KV.of(5.4, "goodbye"), KV.of(7.2, "hello again"))) + .apply(ParDo.of(fn)); + } + + @Test + public void testTimerNotKeyed() { + final String timerId = "foo"; + + DoFn fn = + new DoFn() { + + @TimerId(timerId) + private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void processElement( + ProcessContext c, @TimerId(timerId) Timer timer) {} + + @OnTimer(timerId) + public void onTimer() {} + }; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("timer"); + thrown.expectMessage("KvCoder"); + + pipeline.apply(Create.of("hello", "goodbye", "hello again")).apply(ParDo.of(fn)); + } + + @Test + public void testTimerNotDeterministic() { + final String timerId = "foo"; + + // DoubleCoder is not deterministic, so this should crash + DoFn, Integer> fn = + new DoFn, Integer>() { + + @TimerId(timerId) + private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void processElement( + ProcessContext c, @TimerId(timerId) Timer timer) {} + + @OnTimer(timerId) + public void onTimer() {} + }; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("timer"); + thrown.expectMessage("deterministic"); + + pipeline + .apply(Create.of(KV.of(1.0, "hello"), KV.of(5.4, "goodbye"), KV.of(7.2, "hello again"))) + .apply(ParDo.of(fn)); + } + @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class}) public void testValueStateCoderInference() { From 2295b905ecb055d9348170948e25f89665dd647d Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Fri, 23 Jun 2017 14:31:58 -0700 Subject: 
[PATCH 174/200] [BEAM-1347] Rename DoFnRunnerFactory to FnApiDoFnRunner. --- .../core/{DoFnRunnerFactory.java => FnApiDoFnRunner.java} | 2 +- ...oFnRunnerFactoryTest.java => FnApiDoFnRunnerTest.java} | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) rename sdks/java/harness/src/main/java/org/apache/beam/runners/core/{DoFnRunnerFactory.java => FnApiDoFnRunner.java} (99%) rename sdks/java/harness/src/test/java/org/apache/beam/runners/core/{DoFnRunnerFactoryTest.java => FnApiDoFnRunnerTest.java} (97%) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java similarity index 99% rename from sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java rename to sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java index 3c0b6ebcb408e..adf735ada1ddf 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/DoFnRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java @@ -54,7 +54,7 @@ *

    TODO: Move DoFnRunners into SDK harness and merge the methods below into it removing this * class. */ -public class DoFnRunnerFactory { +public class FnApiDoFnRunner { private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java similarity index 97% rename from sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java rename to sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java index 62646ffa9710f..ae5cbacbec167 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java @@ -62,9 +62,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Tests for {@link DoFnRunnerFactory}. */ +/** Tests for {@link FnApiDoFnRunner}. */ @RunWith(JUnit4.class) -public class DoFnRunnerFactoryTest { +public class FnApiDoFnRunnerTest { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final Coder> STRING_CODER = @@ -155,7 +155,7 @@ public void testCreatingAndProcessingDoFn() throws Exception { List startFunctions = new ArrayList<>(); List finishFunctions = new ArrayList<>(); - new DoFnRunnerFactory.Factory<>().createRunnerForPTransform( + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform( PipelineOptionsFactory.create(), null /* beamFnDataClient */, pTransformId, @@ -199,7 +199,7 @@ public void testCreatingAndProcessingDoFn() throws Exception { public void testRegistration() { for (Registrar registrar : ServiceLoader.load(Registrar.class)) { - if (registrar instanceof DoFnRunnerFactory.Registrar) { + if (registrar instanceof FnApiDoFnRunner.Registrar) { assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); return; } From 7da08c981dcf49f91595a1b78abbaeb84ccbf287 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Fri, 23 Jun 2017 14:34:36 -0700 Subject: [PATCH 175/200] [BEAM-1347] Add DoFnRunner specific to Fn Api. 
--- sdks/java/harness/pom.xml | 10 + .../beam/runners/core/FnApiDoFnRunner.java | 483 +++++++++++++++--- .../runners/core/FnApiDoFnRunnerTest.java | 7 +- 3 files changed, 438 insertions(+), 62 deletions(-) diff --git a/sdks/java/harness/pom.xml b/sdks/java/harness/pom.xml index 9cfadc215edac..fe5c2f1c0c060 100644 --- a/sdks/java/harness/pom.xml +++ b/sdks/java/harness/pom.xml @@ -81,6 +81,11 @@ beam-runners-core-java + + org.apache.beam + beam-runners-core-construction-java + + org.apache.beam beam-runners-google-cloud-dataflow-java @@ -149,6 +154,11 @@ linux-x86_64 + + joda-time + joda-time + + org.slf4j slf4j-api diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java index adf735ada1ddf..b3cf3a76de985 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java @@ -27,49 +27,59 @@ import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import com.google.protobuf.InvalidProtocolBufferException; -import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; +import java.util.Iterator; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fake.FakeStepContext; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.DoFnInfo; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.State; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; +import org.apache.beam.sdk.transforms.DoFn.ProcessContext; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; +import org.joda.time.Instant; /** - * Classes associated with converting {@link RunnerApi.PTransform}s to {@link DoFnRunner}s. - * - *

    TODO: Move DoFnRunners into SDK harness and merge the methods below into it removing this - * class. + * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers + * of abstraction caused by StateInternals/TimerInternals since they model state and timer + * concepts differently. */ -public class FnApiDoFnRunner { - - private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; - - /** A registrar which provides a factory to handle Java {@link DoFn}s. */ +public class FnApiDoFnRunner implements DoFnRunner { + /** + * A registrar which provides a factory to handle Java {@link DoFn}s. + */ @AutoService(PTransformRunnerFactory.Registrar.class) public static class Registrar implements PTransformRunnerFactory.Registrar { @Override public Map getPTransformRunnerFactories() { - return ImmutableMap.of(URN, new Factory()); + return ImmutableMap.of(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory()); } } - /** A factory for {@link DoFnRunner}s. */ + /** A factory for {@link FnApiDoFnRunner}. */ static class Factory implements PTransformRunnerFactory> { @@ -105,9 +115,9 @@ public DoFnRunner createRunnerForPTransform( throw new IllegalArgumentException( String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); } - DoFnInfo doFnInfo = - (DoFnInfo) - SerializableUtils.deserializeFromByteArray(serializedFn.toByteArray(), "DoFnInfo"); + @SuppressWarnings({"unchecked", "rawtypes"}) + DoFnInfo doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray( + serializedFn.toByteArray(), "DoFnInfo"); // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. checkArgument( @@ -119,54 +129,26 @@ public DoFnRunner createRunnerForPTransform( doFnInfo.getOutputMap()); ImmutableMultimap.Builder, - ThrowingConsumer>> tagToOutput = + ThrowingConsumer>> tagToOutputMapBuilder = ImmutableMultimap.builder(); for (Map.Entry> entry : doFnInfo.getOutputMap().entrySet()) { @SuppressWarnings({"unchecked", "rawtypes"}) - Collection>> consumers = - (Collection) outputMap.get(Long.toString(entry.getKey())); - tagToOutput.putAll(entry.getValue(), consumers); + Collection>> consumers = + outputMap.get(Long.toString(entry.getKey())); + tagToOutputMapBuilder.putAll(entry.getValue(), consumers); } + ImmutableMultimap, ThrowingConsumer>> tagToOutputMap = + tagToOutputMapBuilder.build(); + @SuppressWarnings({"unchecked", "rawtypes"}) - Map, Collection>>> tagBasedOutputMap = - (Map) tagToOutput.build().asMap(); - - OutputManager outputManager = - new OutputManager() { - Map, Collection>>> tupleTagToOutput = - tagBasedOutputMap; - - @Override - public void output(TupleTag tag, WindowedValue output) { - try { - Collection>> consumers = - tupleTagToOutput.get(tag); - if (consumers == null) { - /* This is a normal case, e.g., if a DoFn has output but that output is not - * consumed. Drop the output. 
*/ - return; - } - for (ThrowingConsumer> consumer : consumers) { - consumer.accept(output); - } - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }; - - @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) - DoFnRunner runner = - DoFnRunners.simpleRunner( - pipelineOptions, - (DoFn) doFnInfo.getDoFn(), - NullSideInputReader.empty(), /* TODO */ - outputManager, - (TupleTag) doFnInfo.getOutputMap().get(doFnInfo.getMainOutput()), - new ArrayList<>(doFnInfo.getOutputMap().values()), - new FakeStepContext(), - (WindowingStrategy) doFnInfo.getWindowingStrategy()); + DoFnRunner runner = new FnApiDoFnRunner<>( + pipelineOptions, + doFnInfo.getDoFn(), + (Collection>>) (Collection) + tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), + tagToOutputMap, + doFnInfo.getWindowingStrategy()); // Register the appropriate handlers. addStartFunction.accept(runner::startBundle); @@ -179,4 +161,387 @@ public void output(TupleTag tag, WindowedValue output) { return runner; } } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + private final PipelineOptions pipelineOptions; + private final DoFn doFn; + private final Collection>> mainOutputConsumers; + private final Multimap, ThrowingConsumer>> outputMap; + private final DoFnInvoker doFnInvoker; + private final StartBundleContext startBundleContext; + private final ProcessBundleContext processBundleContext; + private final FinishBundleContext finishBundleContext; + + /** + * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + */ + private WindowedValue currentElement; + + /** + * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. + */ + private BoundedWindow currentWindow; + + FnApiDoFnRunner( + PipelineOptions pipelineOptions, + DoFn doFn, + Collection>> mainOutputConsumers, + Multimap, ThrowingConsumer>> outputMap, + WindowingStrategy windowingStrategy) { + this.pipelineOptions = pipelineOptions; + this.doFn = doFn; + this.mainOutputConsumers = mainOutputConsumers; + this.outputMap = outputMap; + this.doFnInvoker = DoFnInvokers.invokerFor(doFn); + this.startBundleContext = new StartBundleContext(); + this.processBundleContext = new ProcessBundleContext(); + this.finishBundleContext = new FinishBundleContext(); + } + + @Override + public void startBundle() { + doFnInvoker.invokeStartBundle(startBundleContext); + } + + @Override + public void processElement(WindowedValue elem) { + currentElement = elem; + try { + Iterator windowIterator = + (Iterator) elem.getWindows().iterator(); + while (windowIterator.hasNext()) { + currentWindow = windowIterator.next(); + doFnInvoker.invokeProcessElement(processBundleContext); + } + } finally { + currentElement = null; + currentWindow = null; + } + } + + @Override + public void onTimer( + String timerId, + BoundedWindow window, + Instant timestamp, + TimeDomain timeDomain) { + throw new UnsupportedOperationException("TODO: Add support for timers"); + } + + @Override + public void finishBundle() { + doFnInvoker.invokeFinishBundle(finishBundleContext); + } + + /** + * Outputs the given element to the specified set of consumers wrapping any exceptions. 
+ */ + private void outputTo( + Collection>> consumers, + WindowedValue output) { + Iterator>> consumerIterator; + try { + for (ThrowingConsumer> consumer : consumers) { + consumer.accept(output); + } + } catch (Throwable t) { + throw UserCodeException.wrap(t); + } + } + + /** + * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.StartBundle @StartBundle}. + */ + private class StartBundleContext + extends DoFn.StartBundleContext + implements DoFnInvoker.ArgumentProvider { + + private StartBundleContext() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public PipelineOptions pipelineOptions() { + return pipelineOptions; + } + + @Override + public BoundedWindow window() { + throw new UnsupportedOperationException( + "Cannot access window outside of @ProcessElement and @OnTimer methods."); + } + + @Override + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + return this; + } + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access FinishBundleContext outside of @FinishBundle method."); + } + + @Override + public DoFn.ProcessContext processContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access ProcessContext outside of @ProcessElement method."); + } + + @Override + public DoFn.OnTimerContext onTimerContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access OnTimerContext outside of @OnTimer methods."); + } + + @Override + public RestrictionTracker restrictionTracker() { + throw new UnsupportedOperationException( + "Cannot access RestrictionTracker outside of @ProcessElement method."); + } + + @Override + public State state(String stateId) { + throw new UnsupportedOperationException( + "Cannot access state outside of @ProcessElement and @OnTimer methods."); + } + + @Override + public Timer timer(String timerId) { + throw new UnsupportedOperationException( + "Cannot access timers outside of @ProcessElement and @OnTimer methods."); + } + } + + /** + * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}. 
+ */ + private class ProcessBundleContext + extends DoFn.ProcessContext + implements DoFnInvoker.ArgumentProvider { + + private ProcessBundleContext() { + doFn.super(); + } + + @Override + public BoundedWindow window() { + return currentWindow; + } + + @Override + public DoFn.StartBundleContext startBundleContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access StartBundleContext outside of @StartBundle method."); + } + + @Override + public DoFn.FinishBundleContext finishBundleContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access FinishBundleContext outside of @FinishBundle method."); + } + + @Override + public ProcessContext processContext(DoFn doFn) { + return this; + } + + @Override + public OnTimerContext onTimerContext(DoFn doFn) { + throw new UnsupportedOperationException("TODO: Add support for timers"); + } + + @Override + public RestrictionTracker restrictionTracker() { + throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); + } + + @Override + public State state(String stateId) { + throw new UnsupportedOperationException("TODO: Add support for state"); + } + + @Override + public Timer timer(String timerId) { + throw new UnsupportedOperationException("TODO: Add support for timers"); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public PipelineOptions pipelineOptions() { + return pipelineOptions; + } + + @Override + public void output(OutputT output) { + outputTo(mainOutputConsumers, + WindowedValue.of( + output, + currentElement.getTimestamp(), + currentWindow, + currentElement.getPane())); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + outputTo(mainOutputConsumers, + WindowedValue.of( + output, + timestamp, + currentWindow, + currentElement.getPane())); + } + + @Override + public void output(TupleTag tag, T output) { + Collection>> consumers = (Collection) outputMap.get(tag); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, + WindowedValue.of( + output, + currentElement.getTimestamp(), + currentWindow, + currentElement.getPane())); + } + + @Override + public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + Collection>> consumers = (Collection) outputMap.get(tag); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, + WindowedValue.of( + output, + timestamp, + currentWindow, + currentElement.getPane())); + } + + @Override + public InputT element() { + return currentElement.getValue(); + } + + @Override + public T sideInput(PCollectionView view) { + throw new UnsupportedOperationException("TODO: Support side inputs"); + } + + @Override + public Instant timestamp() { + return currentElement.getTimestamp(); + } + + @Override + public PaneInfo pane() { + return currentElement.getPane(); + } + + @Override + public void updateWatermark(Instant watermark) { + throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); + } + } + + /** + * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.FinishBundle @FinishBundle}. 
+ */ + private class FinishBundleContext + extends DoFn.FinishBundleContext + implements DoFnInvoker.ArgumentProvider { + + private FinishBundleContext() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public PipelineOptions pipelineOptions() { + return pipelineOptions; + } + + @Override + public BoundedWindow window() { + throw new UnsupportedOperationException( + "Cannot access window outside of @ProcessElement and @OnTimer methods."); + } + + @Override + public DoFn.StartBundleContext startBundleContext( + DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access StartBundleContext outside of @StartBundle method."); + } + + @Override + public DoFn.FinishBundleContext finishBundleContext( + DoFn doFn) { + return this; + } + + @Override + public DoFn.ProcessContext processContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access ProcessContext outside of @ProcessElement method."); + } + + @Override + public DoFn.OnTimerContext onTimerContext(DoFn doFn) { + throw new UnsupportedOperationException( + "Cannot access OnTimerContext outside of @OnTimer methods."); + } + + @Override + public RestrictionTracker restrictionTracker() { + throw new UnsupportedOperationException( + "Cannot access RestrictionTracker outside of @ProcessElement method."); + } + + @Override + public State state(String stateId) { + throw new UnsupportedOperationException( + "Cannot access state outside of @ProcessElement and @OnTimer methods."); + } + + @Override + public Timer timer(String timerId) { + throw new UnsupportedOperationException( + "Cannot access timers outside of @ProcessElement and @OnTimer methods."); + } + + @Override + public void output(OutputT output, Instant timestamp, BoundedWindow window) { + outputTo(mainOutputConsumers, + WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + + @Override + public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { + Collection>> consumers = (Collection) outputMap.get(tag); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo(consumers, + WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + } + } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java index ae5cbacbec167..c4df77af8b29d 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/FnApiDoFnRunnerTest.java @@ -44,6 +44,7 @@ import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; +import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.runners.dataflow.util.DoFnInfo; import org.apache.beam.sdk.coders.Coder; @@ -71,7 +72,6 @@ public class FnApiDoFnRunnerTest { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); private static final String STRING_CODER_SPEC_ID = "999L"; private static final RunnerApi.Coder STRING_CODER_SPEC; - private static final String URN = "urn:org.apache.beam:dofn:java:0.1"; static { try { @@ -132,7 +132,7 @@ public void testCreatingAndProcessingDoFn() throws 
Exception { Long.parseLong(mainOutputId), TestDoFn.mainOutput, Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn("urn:org.apache.beam:dofn:java:0.1") + .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) .setParameter(Any.pack(BytesValue.newBuilder() .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) .build())) @@ -200,7 +200,8 @@ public void testRegistration() { for (Registrar registrar : ServiceLoader.load(Registrar.class)) { if (registrar instanceof FnApiDoFnRunner.Registrar) { - assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + assertThat(registrar.getPTransformRunnerFactories(), + IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)); return; } } From 1fa3bfe92bc59a85bfcf12c47c68206757ce238a Mon Sep 17 00:00:00 2001 From: Valentyn Tymofieiev Date: Fri, 7 Jul 2017 15:14:56 -0700 Subject: [PATCH 176/200] Set the type of batch jobs to FNAPI_BATCH when beam_fn_api experiment is specified. --- .../runners/dataflow/dataflow_runner.py | 16 ++-------- .../runners/dataflow/internal/apiclient.py | 29 +++++++++++++++++-- .../dataflow/internal/apiclient_test.py | 5 +--- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 57bcc5e8cdda0..059e139020c93 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -46,8 +46,8 @@ from apache_beam.runners.runner import PipelineState from apache_beam.transforms.display import DisplayData from apache_beam.typehints import typehints -from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import TestOptions from apache_beam.utils.plugin import BeamPlugin @@ -65,12 +65,6 @@ class DataflowRunner(PipelineRunner): if blocking is set to False. """ - # Environment version information. It is passed to the service during a - # a job submission and is used by the service to establish what features - # are expected by the workers. - BATCH_ENVIRONMENT_MAJOR_VERSION = '6' - STREAMING_ENVIRONMENT_MAJOR_VERSION = '1' - # A list of PTransformOverride objects to be applied before running a pipeline # using DataflowRunner. 
# Currently this only works for overrides where the input and output types do @@ -268,15 +262,9 @@ def run(self, pipeline): if test_options.dry_run: return None - standard_options = pipeline._options.view_as(StandardOptions) - if standard_options.streaming: - job_version = DataflowRunner.STREAMING_ENVIRONMENT_MAJOR_VERSION - else: - job_version = DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION - # Get a Dataflow API client and set its options self.dataflow_client = apiclient.DataflowApplicationClient( - pipeline._options, job_version) + pipeline._options) # Create the job result = DataflowPipelineResult( diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index edac9d7d55858..33dfe19529a5d 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -49,6 +49,13 @@ from apache_beam.options.pipeline_options import WorkerOptions +# Environment version information. It is passed to the service during a +# a job submission and is used by the service to establish what features +# are expected by the workers. +_LEGACY_ENVIRONMENT_MAJOR_VERSION = '6' +_FNAPI_ENVIRONMENT_MAJOR_VERSION = '1' + + class Step(object): """Wrapper for a dataflow Step protobuf.""" @@ -146,7 +153,10 @@ def __init__(self, packages, options, environment_version): if self.standard_options.streaming: job_type = 'FNAPI_STREAMING' else: - job_type = 'PYTHON_BATCH' + if _use_fnapi(options): + job_type = 'FNAPI_BATCH' + else: + job_type = 'PYTHON_BATCH' self.proto.version.additionalProperties.extend([ dataflow.Environment.VersionValue.AdditionalProperty( key='job_type', @@ -360,11 +370,16 @@ def __reduce__(self): class DataflowApplicationClient(object): """A Dataflow API client used by application code to create and query jobs.""" - def __init__(self, options, environment_version): + def __init__(self, options): """Initializes a Dataflow API client object.""" self.standard_options = options.view_as(StandardOptions) self.google_cloud_options = options.view_as(GoogleCloudOptions) - self.environment_version = environment_version + + if _use_fnapi(options): + self.environment_version = _FNAPI_ENVIRONMENT_MAJOR_VERSION + else: + self.environment_version = _LEGACY_ENVIRONMENT_MAJOR_VERSION + if self.google_cloud_options.no_auth: credentials = None else: @@ -706,6 +721,14 @@ def translate_mean(accumulator, metric_update): metric_update.kind = None +def _use_fnapi(pipeline_options): + standard_options = pipeline_options.view_as(StandardOptions) + debug_options = pipeline_options.view_as(DebugOptions) + + return standard_options.streaming or ( + debug_options.experiments and 'beam_fn_api' in debug_options.experiments) + + # To enable a counter on the service, add it to this dictionary. 
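# --- Illustrative sketch (not part of the patch above) ---
# A minimal, hedged example of how the beam_fn_api experiment is expected to
# drive the job type and environment version chosen by the patched apiclient.
# PipelineOptions, StandardOptions and DebugOptions are real Beam option
# classes; describe_job is a hypothetical helper that only mirrors the logic
# added in _use_fnapi and the Environment constructor above.
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions


def describe_job(options):
  standard_options = options.view_as(StandardOptions)
  debug_options = options.view_as(DebugOptions)
  use_fnapi = standard_options.streaming or (
      debug_options.experiments and
      'beam_fn_api' in debug_options.experiments)
  if standard_options.streaming:
    job_type = 'FNAPI_STREAMING'
  elif use_fnapi:
    job_type = 'FNAPI_BATCH'
  else:
    job_type = 'PYTHON_BATCH'
  # '1' and '6' follow the _FNAPI/_LEGACY_ENVIRONMENT_MAJOR_VERSION constants
  # introduced in this patch.
  environment_version = '1' if use_fnapi else '6'
  return job_type, environment_version


# Batch job opting into the Fn API via --experiments=beam_fn_api.
print(describe_job(PipelineOptions(['--experiments=beam_fn_api'])))
# -> ('FNAPI_BATCH', '1')

# Plain batch job keeps the legacy worker environment.
print(describe_job(PipelineOptions([])))
# -> ('PYTHON_BATCH', '6')
# --- End of illustrative sketch ---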
metric_translations = { cy_combiners.CountCombineFn: ('sum', translate_scalar), diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 55211f7588aae..407ffcf2ad722 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -22,7 +22,6 @@ from apache_beam.metrics.cells import DistributionData from apache_beam.options.pipeline_options import PipelineOptions -from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner from apache_beam.runners.dataflow.internal.clients import dataflow # Protect against environments where apitools library is not available. @@ -40,9 +39,7 @@ class UtilTest(unittest.TestCase): @unittest.skip("Enable once BEAM-1080 is fixed.") def test_create_application_client(self): pipeline_options = PipelineOptions() - apiclient.DataflowApplicationClient( - pipeline_options, - DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION) + apiclient.DataflowApplicationClient(pipeline_options) def test_set_network(self): pipeline_options = PipelineOptions( From 25d3baae4baf661c8da13465774fc5cf8988291a Mon Sep 17 00:00:00 2001 From: Mark Liu Date: Fri, 7 Jul 2017 15:20:12 -0700 Subject: [PATCH 177/200] [BEAM-2570] Fix breakage after cloud-bigquery updated --- sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py | 6 +++--- .../apache_beam/io/gcp/tests/bigquery_matcher_test.py | 2 +- sdks/python/setup.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py index 844cbc5fa2248..d6f0e97298c5f 100644 --- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py +++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py @@ -92,9 +92,9 @@ def _query_with_retry(self, bigquery_client): page_token = None results = [] while True: - rows, _, page_token = query.fetch_data(page_token=page_token) - results.extend(rows) - if not page_token: + for row in query.fetch_data(page_token=page_token): + results.append(row) + if results: break return results diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py index f12293e491c87..5b722856a7b90 100644 --- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py +++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py @@ -53,7 +53,7 @@ def test_bigquery_matcher_success(self, mock_bigquery): matcher = bq_verifier.BigqueryMatcher( 'mock_project', 'mock_query', - 'da39a3ee5e6b4b0d3255bfef95601890afd80709') + '59f9d6bdee30d67ea73b8aded121c3a0280f9cd8') hc_assert_that(self._mock_result, matcher) @patch.object(bigquery, 'Client') diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 8a0c9aefab16d..da82466822e77 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -122,7 +122,7 @@ def get_version(): 'googledatastore==7.0.1', 'google-cloud-pubsub==0.26.0', # GCP packages required by tests - 'google-cloud-bigquery>=0.23.0,<0.26.0', + 'google-cloud-bigquery==0.25.0', ] From cb5061e7149519cb18673f4c572757dce3cc7bd1 Mon Sep 17 00:00:00 2001 From: Thomas Weise Date: Sun, 9 Jul 2017 11:57:43 -0700 Subject: [PATCH 178/200] BEAM-2575 ApexRunner doesn't emit watermarks for additional outputs --- .../operators/ApexParDoOperator.java | 21 ++++++++++++------- .../runners/apex/examples/WordCountTest.java | 8 +++++-- 2 
files changed, 20 insertions(+), 9 deletions(-) diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java index 809ca2a166c50..c3cbab2c54987 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java @@ -359,10 +359,7 @@ private void processWatermark(ApexStreamTuple.WatermarkTuple mark) { } } if (sideInputs.isEmpty()) { - if (traceTuples) { - LOG.debug("\nemitting watermark {}\n", mark); - } - output.emit(mark); + outputWatermark(mark); return; } @@ -370,10 +367,20 @@ private void processWatermark(ApexStreamTuple.WatermarkTuple mark) { Math.min(pushedBackWatermark.get(), currentInputWatermark); if (potentialOutputWatermark > currentOutputWatermark) { currentOutputWatermark = potentialOutputWatermark; - if (traceTuples) { - LOG.debug("\nemitting watermark {}\n", currentOutputWatermark); + outputWatermark(ApexStreamTuple.WatermarkTuple.of(currentOutputWatermark)); + } + } + + private void outputWatermark(ApexStreamTuple.WatermarkTuple mark) { + if (traceTuples) { + LOG.debug("\nemitting {}\n", mark); + } + output.emit(mark); + if (!additionalOutputPortMapping.isEmpty()) { + for (DefaultOutputPort> additionalOutput : + additionalOutputPortMapping.values()) { + additionalOutput.emit(mark); } - output.emit(ApexStreamTuple.WatermarkTuple.of(currentOutputWatermark)); } } diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/WordCountTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/WordCountTest.java index e76096ef78d1f..ba757468b031b 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/WordCountTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/WordCountTest.java @@ -123,11 +123,15 @@ public void testWordCountExample() throws Exception { options.setInputFile(new File(inputFile).getAbsolutePath()); String outputFilePrefix = "target/wordcountresult.txt"; options.setOutput(outputFilePrefix); - WordCountTest.main(TestPipeline.convertToArgs(options)); File outFile1 = new File(outputFilePrefix + "-00000-of-00002"); File outFile2 = new File(outputFilePrefix + "-00001-of-00002"); - Assert.assertTrue(outFile1.exists() && outFile2.exists()); + Assert.assertTrue(!outFile1.exists() || outFile1.delete()); + Assert.assertTrue(!outFile2.exists() || outFile2.delete()); + + WordCountTest.main(TestPipeline.convertToArgs(options)); + + Assert.assertTrue("result files exist", outFile1.exists() && outFile2.exists()); HashSet results = new HashSet<>(); results.addAll(FileUtils.readLines(outFile1)); results.addAll(FileUtils.readLines(outFile2)); From c6f9fdeadaeda68be86e454377f8c665c22a7c0f Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 27 Jun 2017 15:03:11 -0700 Subject: [PATCH 179/200] Reflect #assignsToOneWindow in WindowingStrategy --- .../core/construction/WindowingStrategyTranslation.java | 1 + .../core/construction/WindowingStrategyTranslationTest.java | 3 +++ sdks/common/runner-api/src/main/proto/beam_runner_api.proto | 5 +++++ 3 files changed, 9 insertions(+) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java index 88ebc01b1df8d..1456a3fe80a98 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslation.java @@ -307,6 +307,7 @@ public static RunnerApi.WindowingStrategy toProto( .setAllowedLateness(windowingStrategy.getAllowedLateness().getMillis()) .setTrigger(TriggerTranslation.toProto(windowingStrategy.getTrigger())) .setWindowFn(windowFnSpec) + .setAssignsToOneWindow(windowingStrategy.getWindowFn().assignsToOneWindow()) .setWindowCoderId( components.registerCoder(windowingStrategy.getWindowFn().windowCoder())); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslationTest.java index e406545467876..7a57fd7ac6e90 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WindowingStrategyTranslationTest.java @@ -116,5 +116,8 @@ public void testToProtoAndBackWithComponents() throws Exception { protoComponents.getCodersOrThrow( components.registerCoder(windowingStrategy.getWindowFn().windowCoder())); + assertThat( + proto.getAssignsToOneWindow(), + equalTo(windowingStrategy.getWindowFn().assignsToOneWindow())); } } diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto index 24e907a72dad3..93fea44785522 100644 --- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto +++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto @@ -436,6 +436,11 @@ message WindowingStrategy { // (Required) Indicate whether empty on-time panes should be omitted. 
OnTimeBehavior OnTimeBehavior = 9; + + // (Required) Whether or not the window fn assigns inputs to exactly one window + // + // This knowledge is required for some optimizations + bool assigns_to_one_window = 10; } // Whether or not a PCollection's WindowFn is non-merging, merging, or From 311547aa561bb314a8fe743b6f4677a2eaaaca50 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 10 Jul 2017 15:25:11 -0700 Subject: [PATCH 180/200] Use URNs, not Java classes, in immutability enforcements --- .../beam/runners/direct/DirectRunner.java | 21 +++++++------------ .../ExecutorServiceParallelExecutor.java | 16 ++++++-------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 7a221c4cf5d59..46212246c2ff0 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -38,14 +38,11 @@ import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.PipelineRunner; -import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.ParDo.MultiOutput; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; @@ -72,16 +69,17 @@ public boolean appliesTo(PCollection collection, DirectGraph graph) { IMMUTABILITY { @Override public boolean appliesTo(PCollection collection, DirectGraph graph) { - return CONTAINS_UDF.contains(graph.getProducer(collection).getTransform().getClass()); + return CONTAINS_UDF.contains( + PTransformTranslation.urnForTransform(graph.getProducer(collection).getTransform())); } }; /** * The set of {@link PTransform PTransforms} that execute a UDF. Useful for some enforcements. 
*/ - private static final Set> CONTAINS_UDF = + private static final Set CONTAINS_UDF = ImmutableSet.of( - Read.Bounded.class, Read.Unbounded.class, ParDo.SingleOutput.class, MultiOutput.class); + PTransformTranslation.READ_TRANSFORM_URN, PTransformTranslation.PAR_DO_TRANSFORM_URN); public abstract boolean appliesTo(PCollection collection, DirectGraph graph); @@ -110,22 +108,19 @@ static BundleFactory bundleFactoryFor( return bundleFactory; } - @SuppressWarnings("rawtypes") - private static Map, Collection> + private static Map> defaultModelEnforcements(Set enabledEnforcements) { - ImmutableMap.Builder, Collection> - enforcements = ImmutableMap.builder(); + ImmutableMap.Builder> enforcements = + ImmutableMap.builder(); ImmutableList.Builder enabledParDoEnforcements = ImmutableList.builder(); if (enabledEnforcements.contains(Enforcement.IMMUTABILITY)) { enabledParDoEnforcements.add(ImmutabilityEnforcementFactory.create()); } Collection parDoEnforcements = enabledParDoEnforcements.build(); - enforcements.put(ParDo.SingleOutput.class, parDoEnforcements); - enforcements.put(MultiOutput.class, parDoEnforcements); + enforcements.put(PTransformTranslation.PAR_DO_TRANSFORM_URN, parDoEnforcements); return enforcements.build(); } - } //////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 2f4d1f64ec43a..75e25623ecb97 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -49,11 +49,11 @@ import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.TimerInternals.TimerData; +import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.direct.WatermarkManager.FiredTimers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult.State; import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; @@ -77,9 +77,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { private final DirectGraph graph; private final RootProviderRegistry rootProviderRegistry; private final TransformEvaluatorRegistry registry; - @SuppressWarnings("rawtypes") - private final Map, Collection> - transformEnforcements; + private final Map> transformEnforcements; private final EvaluationContext evaluationContext; @@ -112,9 +110,7 @@ public static ExecutorServiceParallelExecutor create( DirectGraph graph, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, - @SuppressWarnings("rawtypes") - Map, Collection> - transformEnforcements, + Map> transformEnforcements, EvaluationContext context) { return new ExecutorServiceParallelExecutor( targetParallelism, @@ -130,8 +126,7 @@ private ExecutorServiceParallelExecutor( DirectGraph graph, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, - @SuppressWarnings("rawtypes") - Map, Collection> transformEnforcements, + Map> transformEnforcements, EvaluationContext 
context) { this.targetParallelism = targetParallelism; // Don't use Daemon threads for workers. The Pipeline should continue to execute even if there @@ -237,7 +232,8 @@ private void evaluateBundle( Collection enforcements = MoreObjects.firstNonNull( - transformEnforcements.get(transform.getTransform().getClass()), + transformEnforcements.get( + PTransformTranslation.urnForTransform(transform.getTransform())), Collections.emptyList()); TransformExecutor callable = From 521488f8239547c7e93c30e75ecb2462ff114cb8 Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Fri, 30 Jun 2017 10:21:55 -0700 Subject: [PATCH 181/200] [BEAM-1348] Remove deprecated concepts in Fn API (now replaced with Runner API concepts). --- .../fn-api/src/main/proto/beam_fn_api.proto | 151 +----------------- .../harness/control/ProcessBundleHandler.java | 4 +- .../fn/harness/control/RegisterHandler.java | 2 +- .../harness/control/RegisterHandlerTest.java | 8 +- .../apache_beam/runners/pipeline_context.py | 2 +- .../runners/portability/fn_api_runner.py | 2 +- .../apache_beam/runners/worker/sdk_worker.py | 4 +- .../runners/worker/sdk_worker_test.py | 16 +- 8 files changed, 25 insertions(+), 164 deletions(-) diff --git a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto index 8162bc50598b6..9da5afec1b4e3 100644 --- a/sdks/common/fn-api/src/main/proto/beam_fn_api.proto +++ b/sdks/common/fn-api/src/main/proto/beam_fn_api.proto @@ -38,7 +38,6 @@ option java_package = "org.apache.beam.fn.v1"; option java_outer_classname = "BeamFnApi"; import "beam_runner_api.proto"; -import "google/protobuf/any.proto"; import "google/protobuf/timestamp.proto"; /* @@ -67,129 +66,6 @@ message Target { string name = 2; } -// (Deprecated) Information defining a PCollection -// -// Migrate to Runner API. -message PCollection { - // (Required) A reference to a coder. - string coder_reference = 1 [deprecated = true]; - - // TODO: Windowing strategy, ... -} - -// (Deprecated) A primitive transform within Apache Beam. -// -// Migrate to Runner API. -message PrimitiveTransform { - // (Required) A pipeline level unique id which can be used as a reference to - // refer to this. - string id = 1 [deprecated = true]; - - // (Required) A function spec that is used by this primitive - // transform to process data. - FunctionSpec function_spec = 2 [deprecated = true]; - - // A map of distinct input names to target definitions. - // For example, in CoGbk this represents the tag name associated with each - // distinct input name and a list of primitive transforms that are associated - // with the specified input. - map inputs = 3 [deprecated = true]; - - // A map from local output name to PCollection definitions. For example, in - // DoFn this represents the tag name associated with each distinct output. - map outputs = 4 [deprecated = true]; - - // TODO: Should we model side inputs as a special type of input for a - // primitive transform or should it be modeled as the relationship that - // the predecessor input will be a view primitive transform. - // A map of from side input names to side inputs. - map side_inputs = 5 [deprecated = true]; - - // The user name of this step. - // TODO: This should really be in display data and not at this level - string step_name = 6 [deprecated = true]; -} - -/* - * User Definable Functions - * - * This is still unstable mainly due to how we model the side input. 
- */ - -// (Deprecated) Defines the common elements of user-definable functions, -// to allow the SDK to express the information the runner needs to execute work. -// -// Migrate to Runner API. -message FunctionSpec { - // (Required) A pipeline level unique id which can be used as a reference to - // refer to this. - string id = 1 [deprecated = true]; - - // (Required) A globally unique name representing this user definable - // function. - // - // User definable functions use the urn encodings registered such that another - // may implement the user definable function within another language. - // - // For example: - // urn:org.apache.beam:coder:kv:1.0 - string urn = 2 [deprecated = true]; - - // (Required) Reference to specification of execution environment required to - // invoke this function. - string environment_reference = 3 [deprecated = true]; - - // Data used to parameterize this function. Depending on the urn, this may be - // optional or required. - google.protobuf.Any data = 4 [deprecated = true]; -} - -// (Deprecated) Migrate to Runner API. -message SideInput { - // TODO: Coder? - - // For RunnerAPI. - Target input = 1 [deprecated = true]; - - // For FnAPI. - FunctionSpec view_fn = 2 [deprecated = true]; -} - -// (Deprecated) Defines how to encode values into byte streams and decode -// values from byte streams. A coder can be parameterized by additional -// properties which may or may not be language agnostic. -// -// Coders using the urn:org.apache.beam:coder namespace must have their -// encodings registered such that another may implement the encoding within -// another language. -// -// For example: -// urn:org.apache.beam:coder:kv:1.0 -// urn:org.apache.beam:coder:iterable:1.0 -// -// Migrate to Runner API. -message Coder { - // TODO: This looks weird when compared to the other function specs - // which use URN to differentiate themselves. Should "Coder" be embedded - // inside the FunctionSpec data block. - - // The data associated with this coder used to reconstruct it. - FunctionSpec function_spec = 1 [deprecated = true]; - - // A list of component coder references. - // - // For a key-value coder, there must be exactly two component coder references - // where the first reference represents the key coder and the second reference - // is the value coder. - // - // For an iterable coder, there must be exactly one component coder reference - // representing the value coder. - // - // TODO: Perhaps this is redundant with the data of the FunctionSpec - // for known coders? - repeated string component_coder_reference = 2 [deprecated = true]; -} - // A descriptor for connecting to a remote port using the Beam Fn Data API. // Allows for communication between two environments (for example between the // runner and the SDK). @@ -278,33 +154,20 @@ message ProcessBundleDescriptor { // refer to this. string id = 1; - // (Deprecated) A list of primitive transforms that should - // be used to construct the bundle processing graph. - // - // Migrate to Runner API definitions found within transforms field. - repeated PrimitiveTransform primitive_transform = 2 [deprecated = true]; - - // (Deprecated) The set of all coders referenced in this bundle. - // - // Migrate to Runner API defintions found within codersyyy field. - repeated Coder coders = 4 [deprecated = true]; - // (Required) A map from pipeline-scoped id to PTransform. - map transforms = 5; + map transforms = 2; // (Required) A map from pipeline-scoped id to PCollection. 
- map pcollections = 6; + map pcollections = 3; // (Required) A map from pipeline-scoped id to WindowingStrategy. - map windowing_strategies = 7; + map windowing_strategies = 4; // (Required) A map from pipeline-scoped id to Coder. - // TODO: Rename to "coders" once deprecated coders field is removed. Unique - // name is choosen to make it an easy search/replace - map codersyyy = 8; + map coders = 5; // (Required) A map from pipeline-scoped id to Environment. - map environments = 9; + map environments = 6; } // A request to process a given bundle. @@ -385,14 +248,14 @@ message PrimitiveTransformSplit { // // For example, a remote GRPC source will have a specific urn and data // block containing an ElementCountRestriction. - FunctionSpec completed_restriction = 2; + org.apache.beam.runner_api.v1.FunctionSpec completed_restriction = 2; // (Required) A function specification describing the restriction // representing the remainder of work for the primitive transform. // // FOr example, a remote GRPC source will have a specific urn and data // block contain an ElemntCountSkipRestriction. - FunctionSpec remaining_restriction = 3; + org.apache.beam.runner_api.v1.FunctionSpec remaining_restriction = 3; } message ProcessBundleSplitResponse { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 4c4f73d4326b0..2a9cef8cda870 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -49,7 +49,7 @@ /** * Processes {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s by materializing - * the set of required runners for each {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec}, + * the set of required runners for each {@link RunnerApi.FunctionSpec}, * wiring them together based upon the {@code input} and {@code output} map definitions. * *

    Finally executes the DAG based graph by starting all runners in reverse topological order, @@ -166,7 +166,7 @@ private void createRunnerAndConsumersForPTransformRecursively( pTransform, processBundleInstructionId, processBundleDescriptor.getPcollectionsMap(), - processBundleDescriptor.getCodersyyyMap(), + processBundleDescriptor.getCodersMap(), pCollectionIdsToConsumers, addStartFunction, addFinishFunction); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java index 276a1200df018..0e738ac76d882 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/RegisterHandler.java @@ -79,7 +79,7 @@ public BeamFnApi.InstructionResponse.Builder register(BeamFnApi.InstructionReque processBundleDescriptor.getClass()); computeIfAbsent(processBundleDescriptor.getId()).complete(processBundleDescriptor); for (Map.Entry entry - : processBundleDescriptor.getCodersyyyMap().entrySet()) { + : processBundleDescriptor.getCodersMap().entrySet()) { LOG.debug("Registering {} with type {}", entry.getKey(), entry.getValue().getClass()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java index b1f441030fdb5..2b275af4565f9 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java @@ -44,14 +44,14 @@ public class RegisterHandlerTest { .setRegister(BeamFnApi.RegisterRequest.newBuilder() .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder() .setId("1L") - .putCodersyyy("10L", RunnerApi.Coder.newBuilder() + .putCoders("10L", RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:10L").build()) .build()) .build()) .build()) .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId("2L") - .putCodersyyy("20L", RunnerApi.Coder.newBuilder() + .putCoders("20L", RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:20L").build()) .build()) @@ -82,10 +82,10 @@ public BeamFnApi.InstructionResponse call() throws Exception { assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1), handler.getById("2L")); assertEquals( - REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCodersyyyOrThrow("10L"), + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCodersOrThrow("10L"), handler.getById("10L")); assertEquals( - REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCodersyyyOrThrow("20L"), + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCodersOrThrow("20L"), handler.getById("20L")); assertEquals(REGISTER_RESPONSE, responseFuture.get()); } diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index c2ae3f33650dc..a40069b4280e2 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -84,7 +84,7 @@ class PipelineContext(object): def __init__(self, proto=None): if isinstance(proto, 
beam_fn_api_pb2.ProcessBundleDescriptor): proto = beam_runner_api_pb2.Components( - coders=dict(proto.codersyyy.items()), + coders=dict(proto.coders.items()), windowing_strategies=dict(proto.windowing_strategies.items()), environments=dict(proto.environments.items())) for name, cls in self._COMPONENT_TYPES.items(): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index c5438adbdcf0b..f52286456530c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -261,7 +261,7 @@ def only_element(iterable): id=self._next_uid(), transforms=transform_protos, pcollections=pcollection_protos, - codersyyy=dict(context_proto.coders.items()), + coders=dict(context_proto.coders.items()), windowing_strategies=dict(context_proto.windowing_strategies.items()), environments=dict(context_proto.environments.items())) return input_data, side_input_data, runner_sinks, process_bundle_descriptor diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index e1ddfb7807aa8..ae8683047122f 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -249,8 +249,6 @@ def do_instruction(self, request): def register(self, request, unused_instruction_id=None): for process_bundle_descriptor in request.process_bundle_descriptor: self.fns[process_bundle_descriptor.id] = process_bundle_descriptor - for p_transform in list(process_bundle_descriptor.primitive_transform): - self.fns[p_transform.function_spec.id] = p_transform.function_spec return beam_fn_api_pb2.RegisterResponse() def create_execution_tree(self, descriptor): @@ -355,7 +353,7 @@ def create_operation(self, transform_id, consumers): return creator(self, transform_id, transform_proto, parameter, consumers) def get_coder(self, coder_id): - coder_proto = self.descriptor.codersyyy[coder_id] + coder_proto = self.descriptor.coders[coder_id] if coder_proto.spec.spec.urn: return self.context.coders.get_by_id(coder_id) else: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index 553d5b86cbadc..dc72a5ff4faeb 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -28,6 +28,7 @@ import grpc from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners.worker import sdk_worker @@ -61,13 +62,12 @@ def Control(self, response_iterator, context): class SdkWorkerTest(unittest.TestCase): def test_fn_registration(self): - fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)] - - process_bundle_descriptors = [beam_fn_api_pb2.ProcessBundleDescriptor( - id=str(100+ix), - primitive_transform=[ - beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)]) - for ix, fn in enumerate(fns)] + process_bundle_descriptors = [ + beam_fn_api_pb2.ProcessBundleDescriptor( + id=str(100+ix), + transforms={ + str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))}) + for ix in range(4)] test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest( register=beam_fn_api_pb2.RegisterRequest( @@ -83,7 +83,7 @@ def test_fn_registration(self): harness.run() self.assertEqual( harness.worker.fns, - {item.id: item for item in fns + 
process_bundle_descriptors}) + {item.id: item for item in process_bundle_descriptors}) if __name__ == "__main__": From 77ba7a35cdae0b036791cce0682beefeb3fd809b Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Fri, 9 Jun 2017 17:11:32 -0700 Subject: [PATCH 182/200] Adds DynamicDestinations support to FileBasedSink --- .../common/WriteOneFilePerWindow.java | 52 +- .../beam/examples/WindowedWordCountIT.java | 4 +- .../complete/game/utils/WriteToText.java | 43 +- .../construction/WriteFilesTranslation.java | 67 +- .../construction/PTransformMatchersTest.java | 22 +- .../WriteFilesTranslationTest.java | 62 +- .../direct/WriteWithShardingFactory.java | 6 +- .../direct/WriteWithShardingFactoryTest.java | 18 +- .../beam/runners/dataflow/DataflowRunner.java | 15 +- .../runners/dataflow/DataflowRunnerTest.java | 35 +- .../spark/SparkRunnerDebuggerTest.java | 26 +- .../src/main/proto/beam_runner_api.proto | 7 +- .../beam/sdk/coders}/ShardedKeyCoder.java | 18 +- .../java/org/apache/beam/sdk/io/AvroIO.java | 220 +++--- .../java/org/apache/beam/sdk/io/AvroSink.java | 32 +- .../beam/sdk/io/DefaultFilenamePolicy.java | 274 ++++++-- .../beam/sdk/io/DynamicFileDestinations.java | 115 ++++ .../org/apache/beam/sdk/io/FileBasedSink.java | 513 ++++++++------ .../org/apache/beam/sdk/io/TFRecordIO.java | 44 +- .../java/org/apache/beam/sdk/io/TextIO.java | 488 +++++++++---- .../java/org/apache/beam/sdk/io/TextSink.java | 22 +- .../org/apache/beam/sdk/io/WriteFiles.java | 640 +++++++++++------- .../sdk/transforms/SerializableFunctions.java | 50 ++ .../apache/beam/sdk/values}/ShardedKey.java | 10 +- .../org/apache/beam/sdk/io/AvroIOTest.java | 85 ++- .../sdk/io/DefaultFilenamePolicyTest.java | 135 ++-- .../io/DrunkWritableByteChannelFactory.java | 2 +- .../apache/beam/sdk/io/FileBasedSinkTest.java | 93 +-- .../org/apache/beam/sdk/io/SimpleSink.java | 56 +- .../org/apache/beam/sdk/io/TextIOTest.java | 264 +++++++- .../apache/beam/sdk/io/WriteFilesTest.java | 339 ++++++++-- .../beam/sdk/io/gcp/bigquery/BatchLoads.java | 2 + .../io/gcp/bigquery/DynamicDestinations.java | 29 +- .../io/gcp/bigquery/GenerateShardedTable.java | 1 + .../sdk/io/gcp/bigquery/StreamingWriteFn.java | 1 + .../io/gcp/bigquery/StreamingWriteTables.java | 2 + .../sdk/io/gcp/bigquery/TagWithUniqueIds.java | 1 + .../io/gcp/bigquery/WriteBundlesToFiles.java | 2 + .../bigquery/WriteGroupedRecordsToFiles.java | 1 + .../sdk/io/gcp/bigquery/WritePartition.java | 1 + .../beam/sdk/io/gcp/bigquery/WriteTables.java | 1 + .../sdk/io/gcp/bigquery/BigQueryIOTest.java | 2 + .../org/apache/beam/sdk/io/xml/XmlIO.java | 4 +- .../org/apache/beam/sdk/io/xml/XmlSink.java | 21 +- .../apache/beam/sdk/io/xml/XmlSinkTest.java | 4 +- 45 files changed, 2588 insertions(+), 1241 deletions(-) rename sdks/java/{io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery => core/src/main/java/org/apache/beam/sdk/coders}/ShardedKeyCoder.java (80%) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableFunctions.java rename sdks/java/{io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery => core/src/main/java/org/apache/beam/sdk/values}/ShardedKey.java (90%) diff --git a/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java b/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java index 5e6df9c781fd7..49865ba60b7c9 100644 --- 
a/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java +++ b/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java @@ -17,11 +17,12 @@ */ package org.apache.beam.examples.common; -import static com.google.common.base.Verify.verifyNotNull; +import static com.google.common.base.MoreObjects.firstNonNull; import javax.annotation.Nullable; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; @@ -53,22 +54,12 @@ public WriteOneFilePerWindow(String filenamePrefix, Integer numShards) { @Override public PDone expand(PCollection input) { - // filenamePrefix may contain a directory and a filename component. Pull out only the filename - // component from that path for the PerWindowFiles. - String prefix = ""; ResourceId resource = FileBasedSink.convertToFileResourceIfPossible(filenamePrefix); - if (!resource.isDirectory()) { - prefix = verifyNotNull( - resource.getFilename(), - "A non-directory resource should have a non-null filename: %s", - resource); - } - - - TextIO.Write write = TextIO.write() - .to(resource.getCurrentDirectory()) - .withFilenamePolicy(new PerWindowFiles(prefix)) - .withWindowedWrites(); + TextIO.Write write = + TextIO.write() + .to(new PerWindowFiles(resource)) + .withTempDirectory(resource.getCurrentDirectory()) + .withWindowedWrites(); if (numShards != null) { write = write.withNumShards(numShards); } @@ -83,31 +74,36 @@ public PDone expand(PCollection input) { */ public static class PerWindowFiles extends FilenamePolicy { - private final String prefix; + private final ResourceId baseFilename; - public PerWindowFiles(String prefix) { - this.prefix = prefix; + public PerWindowFiles(ResourceId baseFilename) { + this.baseFilename = baseFilename; } public String filenamePrefixForWindow(IntervalWindow window) { + String prefix = + baseFilename.isDirectory() ? 
"" : firstNonNull(baseFilename.getFilename(), ""); return String.format("%s-%s-%s", prefix, FORMATTER.print(window.start()), FORMATTER.print(window.end())); } @Override - public ResourceId windowedFilename( - ResourceId outputDirectory, WindowedContext context, String extension) { + public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { IntervalWindow window = (IntervalWindow) context.getWindow(); - String filename = String.format( - "%s-%s-of-%s%s", - filenamePrefixForWindow(window), context.getShardNumber(), context.getNumShards(), - extension); - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + String filename = + String.format( + "%s-%s-of-%s%s", + filenamePrefixForWindow(window), + context.getShardNumber(), + context.getNumShards(), + outputFileHints.getSuggestedFilenameSuffix()); + return baseFilename + .getCurrentDirectory() + .resolve(filename, StandardResolveOptions.RESOLVE_FILE); } @Override - public ResourceId unwindowedFilename( - ResourceId outputDirectory, Context context, String extension) { + public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Unsupported."); } } diff --git a/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java b/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java index eb7e4c4e936f7..bec795210fa13 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java +++ b/examples/java/src/test/java/org/apache/beam/examples/WindowedWordCountIT.java @@ -32,6 +32,7 @@ import org.apache.beam.examples.common.ExampleUtils; import org.apache.beam.examples.common.WriteOneFilePerWindow.PerWindowFiles; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -149,7 +150,8 @@ private void testWindowedWordCountPipeline(WindowedWordCountITOptions options) t String outputPrefix = options.getOutput(); - PerWindowFiles filenamePolicy = new PerWindowFiles(outputPrefix); + PerWindowFiles filenamePolicy = + new PerWindowFiles(FileBasedSink.convertToFileResourceIfPossible(outputPrefix)); List expectedOutputFiles = Lists.newArrayListWithCapacity(6); diff --git a/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java b/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java index e6c8ddbecc41a..1d601987211b7 100644 --- a/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java +++ b/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java @@ -18,7 +18,6 @@ package org.apache.beam.examples.complete.game.utils; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verifyNotNull; import java.io.Serializable; import java.util.ArrayList; @@ -28,6 +27,7 @@ import java.util.stream.Collectors; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; @@ -111,21 +111,12 @@ public PDone expand(PCollection input) 
{ checkArgument( input.getWindowingStrategy().getWindowFn().windowCoder() == IntervalWindow.getCoder()); - // filenamePrefix may contain a directory and a filename component. Pull out only the filename - // component from that path for the PerWindowFiles. - String prefix = ""; ResourceId resource = FileBasedSink.convertToFileResourceIfPossible(filenamePrefix); - if (!resource.isDirectory()) { - prefix = verifyNotNull( - resource.getFilename(), - "A non-directory resource should have a non-null filename: %s", - resource); - } return input.apply( TextIO.write() - .to(resource.getCurrentDirectory()) - .withFilenamePolicy(new PerWindowFiles(prefix)) + .to(new PerWindowFiles(resource)) + .withTempDirectory(resource.getCurrentDirectory()) .withWindowedWrites() .withNumShards(3)); } @@ -139,31 +130,33 @@ public PDone expand(PCollection input) { */ protected static class PerWindowFiles extends FilenamePolicy { - private final String prefix; + private final ResourceId prefix; - public PerWindowFiles(String prefix) { + public PerWindowFiles(ResourceId prefix) { this.prefix = prefix; } public String filenamePrefixForWindow(IntervalWindow window) { - return String.format("%s-%s-%s", - prefix, formatter.print(window.start()), formatter.print(window.end())); + String filePrefix = prefix.isDirectory() ? "" : prefix.getFilename(); + return String.format( + "%s-%s-%s", filePrefix, formatter.print(window.start()), formatter.print(window.end())); } @Override - public ResourceId windowedFilename( - ResourceId outputDirectory, WindowedContext context, String extension) { + public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { IntervalWindow window = (IntervalWindow) context.getWindow(); - String filename = String.format( - "%s-%s-of-%s%s", - filenamePrefixForWindow(window), context.getShardNumber(), context.getNumShards(), - extension); - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + String filename = + String.format( + "%s-%s-of-%s%s", + filenamePrefixForWindow(window), + context.getShardNumber(), + context.getNumShards(), + outputFileHints.getSuggestedFilenameSuffix()); + return prefix.getCurrentDirectory().resolve(filename, StandardResolveOptions.RESOLVE_FILE); } @Override - public ResourceId unwindowedFilename( - ResourceId outputDirectory, Context context, String extension) { + public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Unsupported."); } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java index 99b77efd9d471..b1d2da435b1e8 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java @@ -26,6 +26,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.io.IOException; +import java.io.Serializable; import java.util.Collections; import java.util.Map; import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; @@ -37,6 +38,7 @@ import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; +import 
org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; @@ -51,32 +53,45 @@ public class WriteFilesTranslation { public static final String CUSTOM_JAVA_FILE_BASED_SINK_URN = "urn:beam:file_based_sink:javasdk:0.1"; + public static final String CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN = + "urn:beam:file_based_sink_format_function:javasdk:0.1"; + @VisibleForTesting - static WriteFilesPayload toProto(WriteFiles transform) { + static WriteFilesPayload toProto(WriteFiles transform) { return WriteFilesPayload.newBuilder() .setSink(toProto(transform.getSink())) + .setFormatFunction(toProto(transform.getFormatFunction())) .setWindowedWrites(transform.isWindowedWrites()) .setRunnerDeterminedSharding( transform.getNumShards() == null && transform.getSharding() == null) .build(); } - private static SdkFunctionSpec toProto(FileBasedSink sink) { + private static SdkFunctionSpec toProto(FileBasedSink sink) { + return toProto(CUSTOM_JAVA_FILE_BASED_SINK_URN, sink); + } + + private static SdkFunctionSpec toProto(SerializableFunction serializableFunction) { + return toProto(CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN, serializableFunction); + } + + private static SdkFunctionSpec toProto(String urn, Serializable serializable) { return SdkFunctionSpec.newBuilder() .setSpec( FunctionSpec.newBuilder() - .setUrn(CUSTOM_JAVA_FILE_BASED_SINK_URN) + .setUrn(urn) .setParameter( Any.pack( BytesValue.newBuilder() .setValue( - ByteString.copyFrom(SerializableUtils.serializeToByteArray(sink))) + ByteString.copyFrom( + SerializableUtils.serializeToByteArray(serializable))) .build()))) .build(); } @VisibleForTesting - static FileBasedSink sinkFromProto(SdkFunctionSpec sinkProto) throws IOException { + static FileBasedSink sinkFromProto(SdkFunctionSpec sinkProto) throws IOException { checkArgument( sinkProto.getSpec().getUrn().equals(CUSTOM_JAVA_FILE_BASED_SINK_URN), "Cannot extract %s instance from %s with URN %s", @@ -87,16 +102,44 @@ static FileBasedSink sinkFromProto(SdkFunctionSpec sinkProto) throws IOExcept byte[] serializedSink = sinkProto.getSpec().getParameter().unpack(BytesValue.class).getValue().toByteArray(); - return (FileBasedSink) + return (FileBasedSink) SerializableUtils.deserializeFromByteArray( serializedSink, FileBasedSink.class.getSimpleName()); } - public static FileBasedSink getSink( - AppliedPTransform, PDone, ? extends PTransform, PDone>> + @VisibleForTesting + static SerializableFunction formatFunctionFromProto( + SdkFunctionSpec sinkProto) throws IOException { + checkArgument( + sinkProto.getSpec().getUrn().equals(CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN), + "Cannot extract %s instance from %s with URN %s", + SerializableFunction.class.getSimpleName(), + FunctionSpec.class.getSimpleName(), + sinkProto.getSpec().getUrn()); + + byte[] serializedFunction = + sinkProto.getSpec().getParameter().unpack(BytesValue.class).getValue().toByteArray(); + + return (SerializableFunction) + SerializableUtils.deserializeFromByteArray( + serializedFunction, FileBasedSink.class.getSimpleName()); + } + + public static FileBasedSink getSink( + AppliedPTransform, PDone, ? extends PTransform, PDone>> + transform) + throws IOException { + return (FileBasedSink) + sinkFromProto(getWriteFilesPayload(transform).getSink()); + } + + public static SerializableFunction getFormatFunction( + AppliedPTransform< + PCollection, PDone, ? 
extends PTransform, PDone>> transform) throws IOException { - return (FileBasedSink) sinkFromProto(getWriteFilesPayload(transform).getSink()); + return formatFunctionFromProto( + getWriteFilesPayload(transform).getFormatFunction()); } public static boolean isWindowedWrites( @@ -124,15 +167,15 @@ private static WriteFilesPayload getWriteFilesPayload( .unpack(WriteFilesPayload.class); } - static class WriteFilesTranslator implements TransformPayloadTranslator> { + static class WriteFilesTranslator implements TransformPayloadTranslator> { @Override - public String getUrn(WriteFiles transform) { + public String getUrn(WriteFiles transform) { return PTransformTranslation.WRITE_FILES_TRANSFORM_URN; } @Override public FunctionSpec translate( - AppliedPTransform> transform, SdkComponents components) { + AppliedPTransform> transform, SdkComponents components) { return FunctionSpec.newBuilder() .setUrn(getUrn(transform.getTransform())) .setParameter(Any.pack(toProto(transform.getTransform()))) diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java index 6459849f24fa9..99d3dd1b9180e 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.DefaultFilenamePolicy; +import org.apache.beam.sdk.io.DynamicFileDestinations; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.LocalResources; @@ -55,6 +56,7 @@ import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; @@ -537,30 +539,32 @@ public void flattenWithDuplicateInputsNonFlatten() { public void writeWithRunnerDeterminedSharding() { ResourceId outputDirectory = LocalResources.fromString("/foo/bar", true /* isDirectory */); FilenamePolicy policy = - DefaultFilenamePolicy.constructUsingStandardParameters( + DefaultFilenamePolicy.fromStandardParameters( StaticValueProvider.of(outputDirectory), DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE, "", false); - WriteFiles write = + WriteFiles write = WriteFiles.to( - new FileBasedSink(StaticValueProvider.of(outputDirectory), policy) { + new FileBasedSink( + StaticValueProvider.of(outputDirectory), DynamicFileDestinations.constant(null)) { @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { return null; } - }); + }, + SerializableFunctions.identity()); assertThat( PTransformMatchers.writeWithRunnerDeterminedSharding().matches(appliedWrite(write)), is(true)); - WriteFiles withStaticSharding = write.withNumShards(3); + WriteFiles withStaticSharding = write.withNumShards(3); assertThat( PTransformMatchers.writeWithRunnerDeterminedSharding() .matches(appliedWrite(withStaticSharding)), is(false)); - WriteFiles withCustomSharding = + WriteFiles 
withCustomSharding = write.withSharding(Sum.integersGlobally().asSingletonView()); assertThat( PTransformMatchers.writeWithRunnerDeterminedSharding() @@ -568,8 +572,8 @@ public WriteOperation createWriteOperation() { is(false)); } - private AppliedPTransform appliedWrite(WriteFiles write) { - return AppliedPTransform., PDone, WriteFiles>of( + private AppliedPTransform appliedWrite(WriteFiles write) { + return AppliedPTransform., PDone, WriteFiles>of( "WriteFiles", Collections., PValue>emptyMap(), Collections., PValue>emptyMap(), diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java index 739034cfa8605..283df1657dedb 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java @@ -26,8 +26,10 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.common.runner.v1.RunnerApi.ParDoPayload; +import org.apache.beam.sdk.io.DynamicFileDestinations; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.io.fs.ResourceId; @@ -36,6 +38,8 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.junit.Test; @@ -56,16 +60,17 @@ public class WriteFilesTranslationTest { @RunWith(Parameterized.class) public static class TestWriteFilesPayloadTranslation { @Parameters(name = "{index}: {0}") - public static Iterable> data() { - return ImmutableList.>of( - WriteFiles.to(new DummySink()), - WriteFiles.to(new DummySink()).withWindowedWrites(), - WriteFiles.to(new DummySink()).withNumShards(17), - WriteFiles.to(new DummySink()).withWindowedWrites().withNumShards(42)); + public static Iterable> data() { + SerializableFunction format = SerializableFunctions.constant(null); + return ImmutableList.of( + WriteFiles.to(new DummySink(), format), + WriteFiles.to(new DummySink(), format).withWindowedWrites(), + WriteFiles.to(new DummySink(), format).withNumShards(17), + WriteFiles.to(new DummySink(), format).withWindowedWrites().withNumShards(42)); } @Parameter(0) - public WriteFiles writeFiles; + public WriteFiles writeFiles; public static TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); @@ -80,7 +85,7 @@ public void testEncodedProto() throws Exception { assertThat(payload.getWindowedWrites(), equalTo(writeFiles.isWindowedWrites())); assertThat( - (FileBasedSink) WriteFilesTranslation.sinkFromProto(payload.getSink()), + (FileBasedSink) WriteFilesTranslation.sinkFromProto(payload.getSink()), equalTo(writeFiles.getSink())); } @@ -89,9 +94,9 @@ public void testExtractionDirectFromTransform() throws Exception { PCollection input = p.apply(Create.of("hello")); PDone output = input.apply(writeFiles); - AppliedPTransform, PDone, 
WriteFiles> appliedPTransform = - AppliedPTransform., PDone, WriteFiles>of( - "foo", input.expand(), output.expand(), writeFiles, p); + AppliedPTransform, PDone, WriteFiles> + appliedPTransform = + AppliedPTransform.of("foo", input.expand(), output.expand(), writeFiles, p); assertThat( WriteFilesTranslation.isRunnerDeterminedSharding(appliedPTransform), @@ -101,7 +106,9 @@ public void testExtractionDirectFromTransform() throws Exception { WriteFilesTranslation.isWindowedWrites(appliedPTransform), equalTo(writeFiles.isWindowedWrites())); - assertThat(WriteFilesTranslation.getSink(appliedPTransform), equalTo(writeFiles.getSink())); + assertThat( + WriteFilesTranslation.getSink(appliedPTransform), + equalTo(writeFiles.getSink())); } } @@ -109,16 +116,16 @@ public void testExtractionDirectFromTransform() throws Exception { * A simple {@link FileBasedSink} for testing serialization/deserialization. Not mocked to avoid * any issues serializing mocks. */ - private static class DummySink extends FileBasedSink { + private static class DummySink extends FileBasedSink { DummySink() { super( StaticValueProvider.of(FileSystems.matchNewResource("nowhere", false)), - new DummyFilenamePolicy()); + DynamicFileDestinations.constant(new DummyFilenamePolicy())); } @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { return new DummyWriteOperation(this); } @@ -130,46 +137,39 @@ public boolean equals(Object other) { DummySink that = (DummySink) other; - return getFilenamePolicy().equals(((DummySink) other).getFilenamePolicy()) - && getBaseOutputDirectoryProvider().isAccessible() - && that.getBaseOutputDirectoryProvider().isAccessible() - && getBaseOutputDirectoryProvider() - .get() - .equals(that.getBaseOutputDirectoryProvider().get()); + return getTempDirectoryProvider().isAccessible() + && that.getTempDirectoryProvider().isAccessible() + && getTempDirectoryProvider().get().equals(that.getTempDirectoryProvider().get()); } @Override public int hashCode() { return Objects.hash( DummySink.class, - getFilenamePolicy(), - getBaseOutputDirectoryProvider().isAccessible() - ? getBaseOutputDirectoryProvider().get() - : null); + getTempDirectoryProvider().isAccessible() ? 
getTempDirectoryProvider().get() : null); } } - private static class DummyWriteOperation extends FileBasedSink.WriteOperation { - public DummyWriteOperation(FileBasedSink sink) { + private static class DummyWriteOperation extends FileBasedSink.WriteOperation { + public DummyWriteOperation(FileBasedSink sink) { super(sink); } @Override - public FileBasedSink.Writer createWriter() throws Exception { + public FileBasedSink.Writer createWriter() throws Exception { throw new UnsupportedOperationException("Should never be called."); } } private static class DummyFilenamePolicy extends FilenamePolicy { @Override - public ResourceId windowedFilename( - ResourceId outputDirectory, WindowedContext c, String extension) { + public ResourceId windowedFilename(WindowedContext c, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Should never be called."); } @Nullable @Override - public ResourceId unwindowedFilename(ResourceId outputDirectory, Context c, String extension) { + public ResourceId unwindowedFilename(Context c, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Should never be called."); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index d8734a1c55544..ba796ae745626 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -60,9 +60,11 @@ class WriteWithShardingFactory public PTransformReplacement, PDone> getReplacementTransform( AppliedPTransform, PDone, PTransform, PDone>> transform) { - try { - WriteFiles replacement = WriteFiles.to(WriteFilesTranslation.getSink(transform)); + WriteFiles replacement = + WriteFiles.to( + WriteFilesTranslation.getSink(transform), + WriteFilesTranslation.getFormatFunction(transform)); if (WriteFilesTranslation.isWindowedWrites(transform)) { replacement = replacement.withWindowedWrites(); } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java index 41d671f5c8e17..546a18135e665 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java @@ -39,9 +39,8 @@ import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.io.DefaultFilenamePolicy; +import org.apache.beam.sdk.io.DynamicFileDestinations; import org.apache.beam.sdk.io.FileBasedSink; -import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.LocalResources; import org.apache.beam.sdk.io.TextIO; @@ -55,6 +54,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -137,21 +137,17 @@ public void 
dynamicallyReshardedWrite() throws Exception { @Test public void withNoShardingSpecifiedReturnsNewTransform() { ResourceId outputDirectory = LocalResources.fromString("/foo", true /* isDirectory */); - FilenamePolicy policy = - DefaultFilenamePolicy.constructUsingStandardParameters( - StaticValueProvider.of(outputDirectory), - DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE, - "", - false); PTransform, PDone> original = WriteFiles.to( - new FileBasedSink(StaticValueProvider.of(outputDirectory), policy) { + new FileBasedSink( + StaticValueProvider.of(outputDirectory), DynamicFileDestinations.constant(null)) { @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { throw new IllegalArgumentException("Should not be used"); } - }); + }, + SerializableFunctions.identity()); @SuppressWarnings("unchecked") PCollection objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 5d9f0f32aca4a..893575969d5a1 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -1455,8 +1455,9 @@ public Map mapOutputs( } @VisibleForTesting - static class StreamingShardedWriteFactory - implements PTransformOverrideFactory, PDone, WriteFiles> { + static class StreamingShardedWriteFactory + implements PTransformOverrideFactory< + PCollection, PDone, WriteFiles> { // We pick 10 as a a default, as it works well with the default number of workers started // by Dataflow. static final int DEFAULT_NUM_SHARDS = 10; @@ -1467,8 +1468,9 @@ static class StreamingShardedWriteFactory } @Override - public PTransformReplacement, PDone> getReplacementTransform( - AppliedPTransform, PDone, WriteFiles> transform) { + public PTransformReplacement, PDone> getReplacementTransform( + AppliedPTransform, PDone, WriteFiles> + transform) { // By default, if numShards is not set WriteFiles will produce one file per bundle. In // streaming, there are large numbers of small bundles, resulting in many tiny files. // Instead we pick max workers * 2 to ensure full parallelism, but prevent too-many files. 
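As a rough sketch of the shard-count choice described in the comment above (the helper below is illustrative and not part of the patch; only the "max workers * 2, otherwise 10" rule is taken from the comment and the DEFAULT_NUM_SHARDS constant):

    // Editorial sketch, not patch content: how a streaming runner might pick the number of
    // shards for WriteFiles. A maxNumWorkers value of 0 or less is treated as "unset".
    static int chooseNumShards(int maxNumWorkers) {
      final int defaultNumShards = 10; // mirrors DEFAULT_NUM_SHARDS above
      return maxNumWorkers > 0 ? maxNumWorkers * 2 : defaultNumShards;
    }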
@@ -1485,7 +1487,10 @@ public PTransformReplacement, PDone> getReplacementTransform( } try { - WriteFiles replacement = WriteFiles.to(WriteFilesTranslation.getSink(transform)); + WriteFiles replacement = + WriteFiles.to( + WriteFilesTranslation.getSink(transform), + WriteFilesTranslation.getFormatFunction(transform)); if (WriteFilesTranslation.isWindowedWrites(transform)) { replacement = replacement.withWindowedWrites(); } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index bc1a04247c223..94985f883a32d 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -76,6 +76,7 @@ import org.apache.beam.sdk.extensions.gcp.auth.NoopCredentialFactory; import org.apache.beam.sdk.extensions.gcp.auth.TestCredential; import org.apache.beam.sdk.extensions.gcp.storage.NoopPathValidator; +import org.apache.beam.sdk.io.DynamicFileDestinations; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.TextIO; @@ -100,6 +101,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.GcsUtil; @@ -1263,30 +1265,39 @@ public void testMergingStatefulRejectedInBatch() throws Exception { private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) { TestPipeline p = TestPipeline.fromOptions(options); - StreamingShardedWriteFactory factory = + StreamingShardedWriteFactory factory = new StreamingShardedWriteFactory<>(p.getOptions()); - WriteFiles original = WriteFiles.to(new TestSink(tmpFolder.toString())); + WriteFiles original = + WriteFiles.to(new TestSink(tmpFolder.toString()), SerializableFunctions.identity()); PCollection objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); - AppliedPTransform, PDone, WriteFiles> originalApplication = - AppliedPTransform.of( - "writefiles", objs.expand(), Collections., PValue>emptyMap(), original, p); - - WriteFiles replacement = (WriteFiles) - factory.getReplacementTransform(originalApplication).getTransform(); + AppliedPTransform, PDone, WriteFiles> + originalApplication = + AppliedPTransform.of( + "writefiles", + objs.expand(), + Collections., PValue>emptyMap(), + original, + p); + + WriteFiles replacement = + (WriteFiles) + factory.getReplacementTransform(originalApplication).getTransform(); assertThat(replacement, not(equalTo((Object) original))); assertThat(replacement.getNumShards().get(), equalTo(expectedNumShards)); } - private static class TestSink extends FileBasedSink { + private static class TestSink extends FileBasedSink { @Override public void validate(PipelineOptions options) {} TestSink(String tmpFolder) { - super(StaticValueProvider.of(FileSystems.matchNewResource(tmpFolder, true)), - null); + super( + StaticValueProvider.of(FileSystems.matchNewResource(tmpFolder, true)), + DynamicFileDestinations.constant(null)); } + @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { throw 
new IllegalArgumentException("Should not be used"); } } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index 64ff98cebfd86..246eb81ecb988 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -52,7 +52,6 @@ import org.joda.time.Duration; import org.junit.Test; - /** * Test {@link SparkRunnerDebugger} with different pipelines. */ @@ -85,17 +84,20 @@ public void debugBatchPipeline() { .apply(MapElements.via(new WordCount.FormatAsTextFn())) .apply(TextIO.write().to("!!PLACEHOLDER-OUTPUT-DIR!!").withNumShards(3).withSuffix(".txt")); - final String expectedPipeline = "sparkContext.parallelize(Arrays.asList(...))\n" - + "_.mapPartitions(new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" - + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" - + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n" - + "_.groupByKey()\n" - + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n" - + "_.mapPartitions(new org.apache.beam.runners.spark" - + ".SparkRunnerDebuggerTest$PlusOne())\n" - + "sparkContext.union(...)\n" - + "_.mapPartitions(new org.apache.beam.runners.spark.examples.WordCount$FormatAsTextFn())\n" - + "_."; + final String expectedPipeline = + "sparkContext.parallelize(Arrays.asList(...))\n" + + "_.mapPartitions(" + + "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" + + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" + + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n" + + "_.groupByKey()\n" + + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n" + + "_.mapPartitions(new org.apache.beam.runners.spark" + + ".SparkRunnerDebuggerTest$PlusOne())\n" + + "sparkContext.union(...)\n" + + "_.mapPartitions(" + + "new org.apache.beam.runners.spark.examples.WordCount$FormatAsTextFn())\n" + + "_."; SparkRunnerDebugger.DebugSparkPipelineResult result = (SparkRunnerDebugger.DebugSparkPipelineResult) pipeline.run(); diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto index 24e907a72dad3..1f74afb52a918 100644 --- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto +++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto @@ -367,9 +367,12 @@ message WriteFilesPayload { // (Required) The SdkFunctionSpec of the FileBasedSink. SdkFunctionSpec sink = 1; - bool windowed_writes = 2; + // (Required) The format function. 
+ SdkFunctionSpec format_function = 2; - bool runner_determined_sharding = 3; + bool windowed_writes = 3; + + bool runner_determined_sharding = 4; } // A coder, the binary format for serialization and deserialization of data in diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKeyCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/ShardedKeyCoder.java similarity index 80% rename from sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKeyCoder.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/coders/ShardedKeyCoder.java index c2b62b72357e1..a86b198543288 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKeyCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/ShardedKeyCoder.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.beam.sdk.io.gcp.bigquery; +package org.apache.beam.sdk.coders; import com.google.common.annotations.VisibleForTesting; import java.io.IOException; @@ -24,17 +24,11 @@ import java.io.OutputStream; import java.util.Arrays; import java.util.List; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StructuredCoder; -import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.values.ShardedKey; - -/** - * A {@link Coder} for {@link ShardedKey}, using a wrapped key {@link Coder}. - */ +/** A {@link Coder} for {@link ShardedKey}, using a wrapped key {@link Coder}. */ @VisibleForTesting -class ShardedKeyCoder - extends StructuredCoder> { +public class ShardedKeyCoder extends StructuredCoder> { public static ShardedKeyCoder of(Coder keyCoder) { return new ShardedKeyCoder<>(keyCoder); } @@ -62,9 +56,7 @@ public void encode(ShardedKey key, OutputStream outStream) @Override public ShardedKey decode(InputStream inStream) throws IOException { - return new ShardedKey<>( - keyCoder.decode(inStream), - shardNumberCoder.decode(inStream)); + return ShardedKey.of(keyCoder.decode(inStream), shardNumberCoder.decode(inStream)); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java index 4143db29096c8..89cadbdede70b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableMap; @@ -35,6 +34,7 @@ import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.Read.Bounded; import org.apache.beam.sdk.io.fs.ResourceId; @@ -43,6 +43,7 @@ import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.values.PBegin; @@ -52,18 +53,19 @@ /** * {@link 
PTransform}s for reading and writing Avro files. * - *

    To read a {@link PCollection} from one or more Avro files, use {@code AvroIO.read()}, - * using {@link AvroIO.Read#from} to specify the filename or filepattern to read from. - * See {@link FileSystems} for information on supported file systems and filepatterns. + *

    To read a {@link PCollection} from one or more Avro files, use {@code AvroIO.read()}, using + * {@link AvroIO.Read#from} to specify the filename or filepattern to read from. See {@link + * FileSystems} for information on supported file systems and filepatterns. * - *

    To read specific records, such as Avro-generated classes, use {@link #read(Class)}. - * To read {@link GenericRecord GenericRecords}, use {@link #readGenericRecords(Schema)} which takes - * a {@link Schema} object, or {@link #readGenericRecords(String)} which takes an Avro schema in a + *

    To read specific records, such as Avro-generated classes, use {@link #read(Class)}. To read + * {@link GenericRecord GenericRecords}, use {@link #readGenericRecords(Schema)} which takes a + * {@link Schema} object, or {@link #readGenericRecords(String)} which takes an Avro schema in a * JSON-encoded string form. An exception will be thrown if a record doesn't match the specified * schema. * *

    For example: - *

     {@code
    + *
    + * 
    {@code
      * Pipeline p = ...;
      *
      * // A simple Read of a local file (only runs locally):
    @@ -75,34 +77,33 @@
      * PCollection records =
      *     p.apply(AvroIO.readGenericRecords(schema)
      *                .from("gs://my_bucket/path/to/records-*.avro"));
    - * } 
    + * }
    * *

    To write a {@link PCollection} to one or more Avro files, use {@link AvroIO.Write}, using - * {@code AvroIO.write().to(String)} to specify the output filename prefix. The default - * {@link DefaultFilenamePolicy} will use this prefix, in conjunction with a - * {@link ShardNameTemplate} (set via {@link Write#withShardNameTemplate(String)}) and optional - * filename suffix (set via {@link Write#withSuffix(String)}, to generate output filenames in a - * sharded way. You can override this default write filename policy using - * {@link Write#withFilenamePolicy(FileBasedSink.FilenamePolicy)} to specify a custom file naming - * policy. + * {@code AvroIO.write().to(String)} to specify the output filename prefix. The default {@link + * DefaultFilenamePolicy} will use this prefix, in conjunction with a {@link ShardNameTemplate} (set + * via {@link Write#withShardNameTemplate(String)}) and optional filename suffix (set via {@link + * Write#withSuffix(String)}, to generate output filenames in a sharded way. You can override this + * default write filename policy using {@link Write#to(FileBasedSink.FilenamePolicy)} to specify a + * custom file naming policy. * *
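A minimal sketch of overriding the default policy under the API introduced in this patch, assuming a hypothetical FilenamePolicy subclass named MyFilenamePolicy and an illustrative temp location; to(FilenamePolicy) and withTempDirectory(...) are the AvroIO.Write methods this patch adds:

    // Sketch: overriding the default filename policy. MyFilenamePolicy is hypothetical; a
    // custom policy does not see the output prefix, so a temp directory is given explicitly.
    PCollection<AvroAutoGenClass> records = ...;
    records.apply(
        AvroIO.write(AvroAutoGenClass.class)
            .to(new MyFilenamePolicy())
            .withTempDirectory(StaticValueProvider.of(
                FileSystems.matchNewResource("gs://my_bucket/temp", true /* isDirectory */))));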

    By default, all input is put into the global window before writing. If per-window writes are - * desired - for example, when using a streaming runner - - * {@link AvroIO.Write#withWindowedWrites()} will cause windowing and triggering to be - * preserved. When producing windowed writes, the number of output shards must be set explicitly - * using {@link AvroIO.Write#withNumShards(int)}; some runners may set this for you to a - * runner-chosen value, so you may need not set it yourself. A - * {@link FileBasedSink.FilenamePolicy} must be set, and unique windows and triggers must produce - * unique filenames. + * desired - for example, when using a streaming runner - {@link AvroIO.Write#withWindowedWrites()} + * will cause windowing and triggering to be preserved. When producing windowed writes with a + * streaming runner that supports triggers, the number of output shards must be set explicitly using + * {@link AvroIO.Write#withNumShards(int)}; some runners may set this for you to a runner-chosen + * value, so you may need not set it yourself. A {@link FileBasedSink.FilenamePolicy} must be set, + * and unique windows and triggers must produce unique filenames. * - *
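The windowed-write setup described in the paragraph above would look roughly like the following; PerWindowFilenamePolicy is a hypothetical window-aware FilenamePolicy, tempDirectory is an assumed ValueProvider for the temp location, and Window/FixedWindows are the standard Beam windowing transforms rather than anything added by this patch:

    // Sketch: windowed writes need a filename policy that folds in window/pane information,
    // plus an explicit shard count on runners that do not choose one automatically.
    PCollection<AvroAutoGenClass> records = ...;
    records
        .apply(Window.<AvroAutoGenClass>into(FixedWindows.of(Duration.standardMinutes(5))))
        .apply(
            AvroIO.write(AvroAutoGenClass.class)
                .to(new PerWindowFilenamePolicy())
                .withTempDirectory(tempDirectory) // assumed ValueProvider<ResourceId>
                .withWindowedWrites()
                .withNumShards(3));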

    To write specific records, such as Avro-generated classes, use {@link #write(Class)}. - * To write {@link GenericRecord GenericRecords}, use either {@link #writeGenericRecords(Schema)} - * which takes a {@link Schema} object, or {@link #writeGenericRecords(String)} which takes a schema - * in a JSON-encoded string form. An exception will be thrown if a record doesn't match the - * specified schema. + *

    To write specific records, such as Avro-generated classes, use {@link #write(Class)}. To write + * {@link GenericRecord GenericRecords}, use either {@link #writeGenericRecords(Schema)} which takes + * a {@link Schema} object, or {@link #writeGenericRecords(String)} which takes a schema in a + * JSON-encoded string form. An exception will be thrown if a record doesn't match the specified + * schema. * *

    For example: - *

     {@code
    + *
    + * 
    {@code
      * // A simple Write to a local file (only runs locally):
      * PCollection records = ...;
      * records.apply(AvroIO.write(AvroAutoGenClass.class).to("/path/to/file.avro"));
    @@ -113,11 +114,11 @@
      * records.apply("WriteToAvro", AvroIO.writeGenericRecords(schema)
      *     .to("gs://my_bucket/path/to/numbers")
      *     .withSuffix(".avro"));
    - * } 
    + * }
    * - *

    By default, {@link AvroIO.Write} produces output files that are compressed using the - * {@link org.apache.avro.file.Codec CodecFactory.deflateCodec(6)}. This default can - * be changed or overridden using {@link AvroIO.Write#withCodec}. + *

    By default, {@link AvroIO.Write} produces output files that are compressed using the {@link + * org.apache.avro.file.Codec CodecFactory.deflateCodec(6)}. This default can be changed or + * overridden using {@link AvroIO.Write#withCodec}. */ public class AvroIO { /** @@ -258,11 +259,16 @@ public abstract static class Write extends PTransform, PDone> @Nullable abstract ValueProvider getFilenamePrefix(); @Nullable abstract String getShardTemplate(); @Nullable abstract String getFilenameSuffix(); + + @Nullable + abstract ValueProvider getTempDirectory(); + abstract int getNumShards(); @Nullable abstract Class getRecordClass(); @Nullable abstract Schema getSchema(); abstract boolean getWindowedWrites(); @Nullable abstract FilenamePolicy getFilenamePolicy(); + /** * The codec used to encode the blocks in the Avro file. String value drawn from those in * https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html @@ -277,6 +283,9 @@ public abstract static class Write extends PTransform, PDone> abstract static class Builder { abstract Builder setFilenamePrefix(ValueProvider filenamePrefix); abstract Builder setFilenameSuffix(String filenameSuffix); + + abstract Builder setTempDirectory(ValueProvider tempDirectory); + abstract Builder setNumShards(int numShards); abstract Builder setShardTemplate(String shardTemplate); abstract Builder setRecordClass(Class recordClass); @@ -296,9 +305,9 @@ abstract static class Builder { *

    The name of the output files will be determined by the {@link FilenamePolicy} used. * *

    By default, a {@link DefaultFilenamePolicy} will build output filenames using the - * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and - * a common suffix (if supplied using {@link #withSuffix(String)}). This default can be - * overridden using {@link #withFilenamePolicy(FilenamePolicy)}. + * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and a + * common suffix (if supplied using {@link #withSuffix(String)}). This default can be overridden + * using {@link #to(FilenamePolicy)}. */ public Write to(String outputPrefix) { return to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix)); @@ -306,14 +315,21 @@ public Write to(String outputPrefix) { /** * Writes to file(s) with the given output prefix. See {@link FileSystems} for information on - * supported file systems. - * - *

    The name of the output files will be determined by the {@link FilenamePolicy} used. + * supported file systems. This prefix is used by the {@link DefaultFilenamePolicy} to generate + * filenames. * *

    By default, a {@link DefaultFilenamePolicy} will build output filenames using the - * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and - * a common suffix (if supplied using {@link #withSuffix(String)}). This default can be - * overridden using {@link #withFilenamePolicy(FilenamePolicy)}. + * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and a + * common suffix (if supplied using {@link #withSuffix(String)}). This default can be overridden + * using {@link #to(FilenamePolicy)}. + * + *

    This default policy can be overridden using {@link #to(FilenamePolicy)}, in which case + * {@link #withShardNameTemplate(String)} and {@link #withSuffix(String)} should not be set. + * Custom filename policies do not automatically see this prefix - you should explicitly pass + * the prefix into your {@link FilenamePolicy} object if you need this. + * + *

    If {@link #withTempDirectory} has not been called, this filename prefix will be used to + * infer a directory for temporary files. */ @Experimental(Kind.FILESYSTEM) public Write to(ResourceId outputPrefix) { @@ -342,15 +358,22 @@ public Write toResource(ValueProvider outputPrefix) { } /** - * Configures the {@link FileBasedSink.FilenamePolicy} that will be used to name written files. + * Writes to files named according to the given {@link FileBasedSink.FilenamePolicy}. A + * directory for temporary files must be specified using {@link #withTempDirectory}. */ - public Write withFilenamePolicy(FilenamePolicy filenamePolicy) { + public Write to(FilenamePolicy filenamePolicy) { return toBuilder().setFilenamePolicy(filenamePolicy).build(); } + /** Set the base directory used to generate temporary files. */ + @Experimental(Kind.FILESYSTEM) + public Write withTempDirectory(ValueProvider tempDirectory) { + return toBuilder().setTempDirectory(tempDirectory).build(); + } + /** * Uses the given {@link ShardNameTemplate} for naming output files. This option may only be - * used when {@link #withFilenamePolicy(FilenamePolicy)} has not been configured. + * used when using one of the default filename-prefix to() overrides. * *

    See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are * used. @@ -360,8 +383,8 @@ public Write withShardNameTemplate(String shardTemplate) { } /** - * Configures the filename suffix for written files. This option may only be used when - * {@link #withFilenamePolicy(FilenamePolicy)} has not been configured. + * Configures the filename suffix for written files. This option may only be used when using one + * of the default filename-prefix to() overrides. * *

    See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are * used. @@ -402,9 +425,8 @@ public Write withoutSharding() { /** * Preserves windowing of input elements and writes them to files based on the element's window. * - *

    Requires use of {@link #withFilenamePolicy(FileBasedSink.FilenamePolicy)}. Filenames will - * be generated using {@link FilenamePolicy#windowedFilename}. See also - * {@link WriteFiles#withWindowedWrites()}. + *

    If using {@link #to(FileBasedSink.FilenamePolicy)}. Filenames will be generated using + * {@link FilenamePolicy#windowedFilename}. See also {@link WriteFiles#withWindowedWrites()}. */ public Write withWindowedWrites() { return toBuilder().setWindowedWrites(true).build(); @@ -435,32 +457,46 @@ public Write withMetadata(Map metadata) { return toBuilder().setMetadata(ImmutableMap.copyOf(metadata)).build(); } - @Override - public PDone expand(PCollection input) { - checkState(getFilenamePrefix() != null, - "Need to set the filename prefix of an AvroIO.Write transform."); - checkState( - (getFilenamePolicy() == null) - || (getShardTemplate() == null && getFilenameSuffix() == null), - "Cannot set a filename policy and also a filename template or suffix."); - checkState(getSchema() != null, - "Need to set the schema of an AvroIO.Write transform."); - checkState(!getWindowedWrites() || (getFilenamePolicy() != null), - "When using windowed writes, a filename policy must be set via withFilenamePolicy()."); - + DynamicDestinations resolveDynamicDestinations() { FilenamePolicy usedFilenamePolicy = getFilenamePolicy(); if (usedFilenamePolicy == null) { - usedFilenamePolicy = DefaultFilenamePolicy.constructUsingStandardParameters( - getFilenamePrefix(), getShardTemplate(), getFilenameSuffix(), getWindowedWrites()); + usedFilenamePolicy = + DefaultFilenamePolicy.fromStandardParameters( + getFilenamePrefix(), getShardTemplate(), getFilenameSuffix(), getWindowedWrites()); + } + return DynamicFileDestinations.constant(usedFilenamePolicy); + } + + @Override + public PDone expand(PCollection input) { + checkArgument( + getFilenamePrefix() != null || getTempDirectory() != null, + "Need to set either the filename prefix or the tempDirectory of a AvroIO.Write " + + "transform."); + if (getFilenamePolicy() != null) { + checkArgument( + getShardTemplate() == null && getFilenameSuffix() == null, + "shardTemplate and filenameSuffix should only be used with the default " + + "filename policy"); } + return expandTyped(input, resolveDynamicDestinations()); + } - WriteFiles write = WriteFiles.to( - new AvroSink<>( - getFilenamePrefix(), - usedFilenamePolicy, - AvroCoder.of(getRecordClass(), getSchema()), - getCodec(), - getMetadata())); + public PDone expandTyped( + PCollection input, DynamicDestinations dynamicDestinations) { + ValueProvider tempDirectory = getTempDirectory(); + if (tempDirectory == null) { + tempDirectory = getFilenamePrefix(); + } + WriteFiles write = + WriteFiles.to( + new AvroSink<>( + tempDirectory, + dynamicDestinations, + AvroCoder.of(getRecordClass(), getSchema()), + getCodec(), + getMetadata()), + SerializableFunctions.identity()); if (getNumShards() > 0) { write = write.withNumShards(getNumShards()); } @@ -473,31 +509,25 @@ public PDone expand(PCollection input) { @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - checkState( - getFilenamePrefix() != null, - "Unable to populate DisplayData for invalid AvroIO.Write (unset output prefix)."); - String outputPrefixString = null; - if (getFilenamePrefix().isAccessible()) { - ResourceId dir = getFilenamePrefix().get(); - outputPrefixString = dir.toString(); - } else { - outputPrefixString = getFilenamePrefix().toString(); + resolveDynamicDestinations().populateDisplayData(builder); + + String tempDirectory = null; + if (getTempDirectory() != null) { + tempDirectory = + getTempDirectory().isAccessible() + ? 
getTempDirectory().get().toString() + : getTempDirectory().toString(); } builder - .add(DisplayData.item("schema", getRecordClass()) - .withLabel("Record Schema")) - .addIfNotNull(DisplayData.item("filePrefix", outputPrefixString) - .withLabel("Output File Prefix")) - .addIfNotNull(DisplayData.item("shardNameTemplate", getShardTemplate()) - .withLabel("Output Shard Name Template")) - .addIfNotNull(DisplayData.item("fileSuffix", getFilenameSuffix()) - .withLabel("Output File Suffix")) - .addIfNotDefault(DisplayData.item("numShards", getNumShards()) - .withLabel("Maximum Output Shards"), - 0) - .addIfNotDefault(DisplayData.item("codec", getCodec().toString()) - .withLabel("Avro Compression Codec"), - DEFAULT_CODEC.toString()); + .add(DisplayData.item("schema", getRecordClass()).withLabel("Record Schema")) + .addIfNotDefault( + DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"), 0) + .addIfNotDefault( + DisplayData.item("codec", getCodec().toString()).withLabel("Avro Compression Codec"), + DEFAULT_CODEC.toString()) + .addIfNotNull( + DisplayData.item("tempDirectory", tempDirectory) + .withLabel("Directory for temporary files")); builder.include("Metadata", new Metadata()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java index 6c362664d8fa2..c78870b5d7ece 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java @@ -32,39 +32,40 @@ import org.apache.beam.sdk.util.MimeTypes; /** A {@link FileBasedSink} for Avro files. */ -class AvroSink extends FileBasedSink { +class AvroSink extends FileBasedSink { private final AvroCoder coder; private final SerializableAvroCodecFactory codec; private final ImmutableMap metadata; AvroSink( ValueProvider outputPrefix, - FilenamePolicy filenamePolicy, + DynamicDestinations dynamicDestinations, AvroCoder coder, SerializableAvroCodecFactory codec, ImmutableMap metadata) { // Avro handle compression internally using the codec. - super(outputPrefix, filenamePolicy, CompressionType.UNCOMPRESSED); + super(outputPrefix, dynamicDestinations, CompressionType.UNCOMPRESSED); this.coder = coder; this.codec = codec; this.metadata = metadata; } @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { return new AvroWriteOperation<>(this, coder, codec, metadata); } /** A {@link WriteOperation WriteOperation} for Avro files. */ - private static class AvroWriteOperation extends WriteOperation { + private static class AvroWriteOperation extends WriteOperation { private final AvroCoder coder; private final SerializableAvroCodecFactory codec; private final ImmutableMap metadata; - private AvroWriteOperation(AvroSink sink, - AvroCoder coder, - SerializableAvroCodecFactory codec, - ImmutableMap metadata) { + private AvroWriteOperation( + AvroSink sink, + AvroCoder coder, + SerializableAvroCodecFactory codec, + ImmutableMap metadata) { super(sink); this.coder = coder; this.codec = codec; @@ -72,22 +73,23 @@ private AvroWriteOperation(AvroSink sink, } @Override - public Writer createWriter() throws Exception { + public Writer createWriter() throws Exception { return new AvroWriter<>(this, coder, codec, metadata); } } /** A {@link Writer Writer} for Avro files. 
*/ - private static class AvroWriter extends Writer { + private static class AvroWriter extends Writer { private final AvroCoder coder; private DataFileWriter dataFileWriter; private SerializableAvroCodecFactory codec; private final ImmutableMap metadata; - public AvroWriter(WriteOperation writeOperation, - AvroCoder coder, - SerializableAvroCodecFactory codec, - ImmutableMap metadata) { + public AvroWriter( + WriteOperation writeOperation, + AvroCoder coder, + SerializableAvroCodecFactory codec, + ImmutableMap metadata) { super(writeOperation, MimeTypes.BINARY); this.coder = coder; this.codec = codec; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java index f9e4ac4a11d4c..7a60e49ebfb03 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java @@ -20,25 +20,31 @@ import static com.google.common.base.MoreObjects.firstNonNull; import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; import java.text.DecimalFormat; import java.util.Arrays; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * A default {@link FilenamePolicy} for windowed and unwindowed files. This policy is constructed @@ -51,10 +57,7 @@ * {@code WriteOneFilePerWindow} example pipeline. */ public final class DefaultFilenamePolicy extends FilenamePolicy { - - private static final Logger LOG = LoggerFactory.getLogger(DefaultFilenamePolicy.class); - - /** The default sharding name template used in {@link #constructUsingStandardParameters}. */ + /** The default sharding name template. */ public static final String DEFAULT_UNWINDOWED_SHARD_TEMPLATE = ShardNameTemplate.INDEX_OF_MAX; /** The default windowed sharding name template used when writing windowed files. @@ -67,75 +70,184 @@ public final class DefaultFilenamePolicy extends FilenamePolicy { "W-P" + DEFAULT_UNWINDOWED_SHARD_TEMPLATE; /* - * pattern for both windowed and non-windowed file names + * pattern for both windowed and non-windowed file names. 
*/ private static final Pattern SHARD_FORMAT_RE = Pattern.compile("(S+|N+|W|P)"); + /** + * Encapsulates constructor parameters to {@link DefaultFilenamePolicy}. + * + *

    This is used as the {@code DestinationT} argument to allow {@link DefaultFilenamePolicy} + * objects to be dynamically generated. + */ + public static class Params implements Serializable { + private final ValueProvider baseFilename; + private final String shardTemplate; + private final boolean explicitTemplate; + private final String suffix; + + /** + * Construct a default Params object. The shard template will be set to the default {@link + * #DEFAULT_UNWINDOWED_SHARD_TEMPLATE} value. + */ + public Params() { + this.baseFilename = null; + this.shardTemplate = DEFAULT_UNWINDOWED_SHARD_TEMPLATE; + this.suffix = ""; + this.explicitTemplate = false; + } + + private Params( + ValueProvider baseFilename, + String shardTemplate, + String suffix, + boolean explicitTemplate) { + this.baseFilename = baseFilename; + this.shardTemplate = shardTemplate; + this.suffix = suffix; + this.explicitTemplate = explicitTemplate; + } + + /** + * Specify that writes are windowed. This affects the default shard template, changing it to + * {@link #DEFAULT_WINDOWED_SHARD_TEMPLATE}. + */ + public Params withWindowedWrites() { + String template = this.shardTemplate; + if (!explicitTemplate) { + template = DEFAULT_WINDOWED_SHARD_TEMPLATE; + } + return new Params(baseFilename, template, suffix, explicitTemplate); + } + + /** Sets the base filename. */ + public Params withBaseFilename(ResourceId baseFilename) { + return withBaseFilename(StaticValueProvider.of(baseFilename)); + } + + /** Like {@link #withBaseFilename(ResourceId)}, but takes in a {@link ValueProvider}. */ + public Params withBaseFilename(ValueProvider baseFilename) { + return new Params(baseFilename, shardTemplate, suffix, explicitTemplate); + } + + /** Sets the shard template. */ + public Params withShardTemplate(String shardTemplate) { + return new Params(baseFilename, shardTemplate, suffix, true); + } + + /** Sets the suffix. */ + public Params withSuffix(String suffix) { + return new Params(baseFilename, shardTemplate, suffix, explicitTemplate); + } + } + + /** A Coder for {@link Params}. */ + public static class ParamsCoder extends AtomicCoder { + private static final ParamsCoder INSTANCE = new ParamsCoder(); + private Coder stringCoder = StringUtf8Coder.of(); + + public static ParamsCoder of() { + return INSTANCE; + } + + @Override + public void encode(Params value, OutputStream outStream) throws IOException { + if (value == null) { + throw new CoderException("cannot encode a null value"); + } + stringCoder.encode(value.baseFilename.get().toString(), outStream); + stringCoder.encode(value.shardTemplate, outStream); + stringCoder.encode(value.suffix, outStream); + } + + @Override + public Params decode(InputStream inStream) throws IOException { + ResourceId prefix = + FileBasedSink.convertToFileResourceIfPossible(stringCoder.decode(inStream)); + String shardTemplate = stringCoder.decode(inStream); + String suffix = stringCoder.decode(inStream); + return new Params() + .withBaseFilename(prefix) + .withShardTemplate(shardTemplate) + .withSuffix(suffix); + } + } + + private final Params params; /** * Constructs a new {@link DefaultFilenamePolicy}. * * @see DefaultFilenamePolicy for more information on the arguments to this function. 
*/ @VisibleForTesting - DefaultFilenamePolicy(ValueProvider prefix, String shardTemplate, String suffix) { - this.prefix = prefix; - this.shardTemplate = shardTemplate; - this.suffix = suffix; + DefaultFilenamePolicy(Params params) { + this.params = params; } /** - * A helper function to construct a {@link DefaultFilenamePolicy} using the standard filename - * parameters, namely a provided {@link ResourceId} for the output prefix, and possibly-null - * shard name template and suffix. + * Construct a {@link DefaultFilenamePolicy}. * - *

    Any filename component of the provided resource will be used as the filename prefix. + *

    This is a shortcut for: * - *

    If provided, the shard name template will be used; otherwise - * {@link #DEFAULT_UNWINDOWED_SHARD_TEMPLATE} will be used for non-windowed file names and - * {@link #DEFAULT_WINDOWED_SHARD_TEMPLATE} will be used for windowed file names. + *

    {@code
    +   *   DefaultFilenamePolicy.fromParams(new Params()
    +   *     .withBaseFilename(baseFilename)
    +   *     .withShardTemplate(shardTemplate)
    +   *     .withSuffix(filenameSuffix)
    +   *     .withWindowedWrites())
    +   * }
    * - *

    If provided, the suffix will be used; otherwise the files will have an empty suffix. + *

    Where the respective {@code with} methods are invoked only if the value is non-null or true. */ - public static DefaultFilenamePolicy constructUsingStandardParameters( - ValueProvider outputPrefix, + public static DefaultFilenamePolicy fromStandardParameters( + ValueProvider baseFilename, @Nullable String shardTemplate, @Nullable String filenameSuffix, boolean windowedWrites) { - // Pick the appropriate default policy based on whether windowed writes are being performed. - String defaultTemplate = - windowedWrites ? DEFAULT_WINDOWED_SHARD_TEMPLATE : DEFAULT_UNWINDOWED_SHARD_TEMPLATE; - return new DefaultFilenamePolicy( - NestedValueProvider.of(outputPrefix, new ExtractFilename()), - firstNonNull(shardTemplate, defaultTemplate), - firstNonNull(filenameSuffix, "")); + Params params = new Params().withBaseFilename(baseFilename); + if (shardTemplate != null) { + params = params.withShardTemplate(shardTemplate); + } + if (filenameSuffix != null) { + params = params.withSuffix(filenameSuffix); + } + if (windowedWrites) { + params = params.withWindowedWrites(); + } + return fromParams(params); } - private final ValueProvider prefix; - private final String shardTemplate; - private final String suffix; + /** Construct a {@link DefaultFilenamePolicy} from a {@link Params} object. */ + public static DefaultFilenamePolicy fromParams(Params params) { + return new DefaultFilenamePolicy(params); + } /** * Constructs a fully qualified name from components. * - *

    The name is built from a prefix, shard template (with shard numbers - * applied), and a suffix. All components are required, but may be empty - * strings. + *

    The name is built from a base filename, shard template (with shard numbers applied), and a + * suffix. All components are required, but may be empty strings. * - *

    Within a shard template, repeating sequences of the letters "S" or "N" - * are replaced with the shard number, or number of shards respectively. - * "P" is replaced with by stringification of current pane. - * "W" is replaced by stringification of current window. + *

    Within a shard template, repeating sequences of the letters "S" or "N" are replaced with the + * shard number, or number of shards respectively. "P" is replaced by stringification of + * current pane. "W" is replaced by stringification of current window. + *

    The numbers are formatted with leading zeros to match the length of the - * repeated sequence of letters. + *

    The numbers are formatted with leading zeros to match the length of the repeated sequence of + * letters. * - *

    For example, if prefix = "output", shardTemplate = "-SSS-of-NNN", and - * suffix = ".txt", with shardNum = 1 and numShards = 100, the following is - * produced: "output-001-of-100.txt". + *

    For example, if baseFilename = "path/to/output", shardTemplate = "-SSS-of-NNN", and suffix = + * ".txt", with shardNum = 1 and numShards = 100, the following is produced: + * "path/to/output-001-of-100.txt". */ - static String constructName( - String prefix, String shardTemplate, String suffix, int shardNum, int numShards, - String paneStr, String windowStr) { + static ResourceId constructName( + ResourceId baseFilename, + String shardTemplate, + String suffix, + int shardNum, + int numShards, + String paneStr, + String windowStr) { + String prefix = extractFilename(baseFilename); // Matcher API works with StringBuffer, rather than StringBuilder. StringBuffer sb = new StringBuffer(); sb.append(prefix); @@ -165,27 +277,37 @@ static String constructName( m.appendTail(sb); sb.append(suffix); - return sb.toString(); + return baseFilename + .getCurrentDirectory() + .resolve(sb.toString(), StandardResolveOptions.RESOLVE_FILE); } @Override @Nullable - public ResourceId unwindowedFilename(ResourceId outputDirectory, Context context, - String extension) { - String filename = constructName(prefix.get(), shardTemplate, suffix, context.getShardNumber(), - context.getNumShards(), null, null) + extension; - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + return constructName( + params.baseFilename.get(), + params.shardTemplate, + params.suffix + outputFileHints.getSuggestedFilenameSuffix(), + context.getShardNumber(), + context.getNumShards(), + null, + null); } @Override - public ResourceId windowedFilename(ResourceId outputDirectory, - WindowedContext context, String extension) { + public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { final PaneInfo paneInfo = context.getPaneInfo(); String paneStr = paneInfoToString(paneInfo); String windowStr = windowToString(context.getWindow()); - String filename = constructName(prefix.get(), shardTemplate, suffix, context.getShardNumber(), - context.getNumShards(), paneStr, windowStr) + extension; - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + return constructName( + params.baseFilename.get(), + params.shardTemplate, + params.suffix + outputFileHints.getSuggestedFilenameSuffix(), + context.getShardNumber(), + context.getNumShards(), + paneStr, + windowStr); } /* @@ -216,24 +338,32 @@ private String paneInfoToString(PaneInfo paneInfo) { @Override public void populateDisplayData(DisplayData.Builder builder) { String filenamePattern; - if (prefix.isAccessible()) { - filenamePattern = String.format("%s%s%s", prefix.get(), shardTemplate, suffix); + if (params.baseFilename.isAccessible()) { + filenamePattern = + String.format("%s%s%s", params.baseFilename.get(), params.shardTemplate, params.suffix); } else { - filenamePattern = String.format("%s%s%s", prefix, shardTemplate, suffix); + filenamePattern = + String.format("%s%s%s", params.baseFilename, params.shardTemplate, params.suffix); } + + String outputPrefixString = null; + outputPrefixString = + params.baseFilename.isAccessible() + ? 
params.baseFilename.get().toString() + : params.baseFilename.toString(); + builder.add(DisplayData.item("filenamePattern", filenamePattern).withLabel("Filename Pattern")); + builder.add(DisplayData.item("filePrefix", outputPrefixString).withLabel("Output File Prefix")); + builder.add(DisplayData.item("fileSuffix", params.suffix).withLabel("Output file Suffix")); builder.add( - DisplayData.item("filenamePattern", filenamePattern) - .withLabel("Filename Pattern")); + DisplayData.item("shardNameTemplate", params.shardTemplate) + .withLabel("Output Shard Name Template")); } - private static class ExtractFilename implements SerializableFunction { - @Override - public String apply(ResourceId input) { - if (input.isDirectory()) { - return ""; - } else { - return firstNonNull(input.getFilename(), ""); - } + private static String extractFilename(ResourceId input) { + if (input.isDirectory()) { + return ""; + } else { + return firstNonNull(input.getFilename(), ""); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java new file mode 100644 index 0000000000000..e7ef0f69b497d --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.io; + +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.ParamsCoder; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.display.DisplayData; + +/** Some helper classes that derive from {@link FileBasedSink.DynamicDestinations}. */ +public class DynamicFileDestinations { + /** Always returns a constant {@link FilenamePolicy}. 
*/ + private static class ConstantFilenamePolicy extends DynamicDestinations { + private final FilenamePolicy filenamePolicy; + + public ConstantFilenamePolicy(FilenamePolicy filenamePolicy) { + this.filenamePolicy = filenamePolicy; + } + + @Override + public Void getDestination(T element) { + return (Void) null; + } + + @Override + public Coder getDestinationCoder() { + return null; + } + + @Override + public Void getDefaultDestination() { + return (Void) null; + } + + @Override + public FilenamePolicy getFilenamePolicy(Void destination) { + return filenamePolicy; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + filenamePolicy.populateDisplayData(builder); + } + } + + /** + * A base class for a {@link DynamicDestinations} object that returns differently-configured + * instances of {@link DefaultFilenamePolicy}. + */ + private static class DefaultPolicyDestinations extends DynamicDestinations { + SerializableFunction destinationFunction; + Params emptyDestination; + + public DefaultPolicyDestinations( + SerializableFunction destinationFunction, Params emptyDestination) { + this.destinationFunction = destinationFunction; + this.emptyDestination = emptyDestination; + } + + @Override + public Params getDestination(UserT element) { + return destinationFunction.apply(element); + } + + @Override + public Params getDefaultDestination() { + return emptyDestination; + } + + @Nullable + @Override + public Coder getDestinationCoder() { + return ParamsCoder.of(); + } + + @Override + public FilenamePolicy getFilenamePolicy(DefaultFilenamePolicy.Params params) { + return DefaultFilenamePolicy.fromParams(params); + } + } + + /** Returns a {@link DynamicDestinations} that always returns the same {@link FilenamePolicy}. */ + public static DynamicDestinations constant(FilenamePolicy filenamePolicy) { + return new ConstantFilenamePolicy<>(filenamePolicy); + } + + /** + * Returns a {@link DynamicDestinations} that returns instances of {@link DefaultFilenamePolicy} + * configured with the given {@link Params}. 
+ */ + public static DynamicDestinations toDefaultPolicies( + SerializableFunction destinationFunction, Params emptyDestination) { + return new DefaultPolicyDestinations<>(destinationFunction, emptyDestination); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java index 8102316b03f5d..583af60df68b3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java @@ -33,6 +33,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; +import java.lang.reflect.TypeVariable; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; @@ -49,8 +50,10 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.StructuredCoder; @@ -73,6 +76,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; import org.apache.beam.sdk.util.MimeTypes; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.commons.compress.compressors.bzip2.BZip2CompressorOutputStream; import org.apache.commons.compress.compressors.deflate.DeflateCompressorOutputStream; import org.joda.time.Instant; @@ -82,43 +86,43 @@ import org.slf4j.LoggerFactory; /** - * Abstract class for file-based output. An implementation of FileBasedSink writes file-based - * output and defines the format of output files (how values are written, headers/footers, MIME - * type, etc.). + * Abstract class for file-based output. An implementation of FileBasedSink writes file-based output + * and defines the format of output files (how values are written, headers/footers, MIME type, + * etc.). * *

    At pipeline construction time, the methods of FileBasedSink are called to validate the sink * and to create a {@link WriteOperation} that manages the process of writing to the sink. * *

The process of writing to a file-based sink is as follows: + *

      - *
    1. An optional subclass-defined initialization, - *
    2. a parallel write of bundles to temporary files, and finally, - *
    3. these temporary files are renamed with final output filenames. + *
1. An optional subclass-defined initialization, + *
2. a parallel write of bundles to temporary files, and finally, + *
3. these temporary files are renamed with final output filenames. *
    * *
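The three steps above correspond roughly to the call sequence below, a sketch of how the WriteFiles transform drives a sink rather than part of the patch; 'sink' stands for a concrete FileBasedSink subclass, and generic parameters are left off because this rendering of the diff has dropped them.

    FileBasedSink.WriteOperation op = sink.createWriteOperation();              // 1. initialization
    FileBasedSink.Writer writer = op.createWriter();                            // 2. write one bundle
    writer.openUnwindowed("bundle-0001", 0 /* shard */, null /* destination */);
    writer.write("a line of output");                                           //    ... into a temp file
    FileBasedSink.FileResult result = writer.close();
    op.finalize(java.util.Collections.singletonList(result));                   // 3. temp files renamed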

    In order to ensure fault-tolerance, a bundle may be executed multiple times (e.g., in the * event of failure/retry or for redundancy). However, exactly one of these executions will have its - * result passed to the finalize method. Each call to {@link Writer#openWindowed} - * or {@link Writer#openUnwindowed} is passed a unique bundle id when it is called - * by the WriteFiles transform, so even redundant or retried bundles will have a unique way of - * identifying - * their output. + * result passed to the finalize method. Each call to {@link Writer#openWindowed} or {@link + * Writer#openUnwindowed} is passed a unique bundle id when it is called by the WriteFiles + * transform, so even redundant or retried bundles will have a unique way of identifying their + * output. * *

    The bundle id should be used to guarantee that a bundle's output is unique. This uniqueness * guarantee is important; if a bundle is to be output to a file, for example, the name of the file * will encode the unique bundle id to avoid conflicts with other writers. * - * {@link FileBasedSink} can take a custom {@link FilenamePolicy} object to determine output - * filenames, and this policy object can be used to write windowed or triggered - * PCollections into separate files per window pane. This allows file output from unbounded - * PCollections, and also works for bounded PCollecctions. + *

{@link FileBasedSink} can take a custom {@link FilenamePolicy} object to determine output + * filenames, and this policy object can be used to write windowed or triggered PCollections into + * separate files per window pane. This allows file output from unbounded PCollections, and also + * works for bounded PCollections. * *
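A sketch of such a policy, written against the windowedFilename/unwindowedFilename signatures introduced further down in this file (not part of the patch; imports are omitted, baseDirectory is an assumed ResourceId, and the getter names follow the Context classes in this file):

    class PerWindowPolicy extends FileBasedSink.FilenamePolicy {
      private final ResourceId baseDirectory;

      PerWindowPolicy(ResourceId baseDirectory) {
        this.baseDirectory = baseDirectory;
      }

      @Override
      public ResourceId windowedFilename(WindowedContext c, FileBasedSink.OutputFileHints hints) {
        String suffix =
            hints.getSuggestedFilenameSuffix() == null ? "" : hints.getSuggestedFilenameSuffix();
        // A production policy should also fold the pane into the name to stay unique per firing.
        String filename = String.format(
            "out-%s-%05d-of-%05d%s", c.getWindow(), c.getShardNumber(), c.getNumShards(), suffix);
        return baseDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE);
      }

      @Override
      public ResourceId unwindowedFilename(Context c, FileBasedSink.OutputFileHints hints) {
        throw new UnsupportedOperationException("This sketch only supports windowed writes");
      }
    }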

    Supported file systems are those registered with {@link FileSystems}. * - * @param the type of values written to the sink. + * @param the type of values written to the sink. */ @Experimental(Kind.FILESYSTEM) -public abstract class FileBasedSink implements Serializable, HasDisplayData { +public abstract class FileBasedSink implements Serializable, HasDisplayData { private static final Logger LOG = LoggerFactory.getLogger(FileBasedSink.class); /** @@ -173,7 +177,7 @@ public WritableByteChannel create(WritableByteChannel channel) throws IOExceptio } @Override - public String getFilenameSuffix() { + public String getSuggestedFilenameSuffix() { return filenameSuffix; } @@ -205,6 +209,8 @@ public static ResourceId convertToFileResourceIfPossible(String outputPrefix) { } } + private final DynamicDestinations dynamicDestinations; + /** * The {@link WritableByteChannelFactory} that is used to wrap the raw data output to the * underlying channel. The default is to not compress the output using @@ -213,8 +219,70 @@ public static ResourceId convertToFileResourceIfPossible(String outputPrefix) { private final WritableByteChannelFactory writableByteChannelFactory; /** - * A naming policy for output files. + * A class that allows value-dependent writes in {@link FileBasedSink}. + * + *

    Users can define a custom type to represent destinations, and provide a mapping to turn this + * destination type into an instance of {@link FilenamePolicy}. */ + @Experimental(Kind.FILESYSTEM) + public abstract static class DynamicDestinations + implements HasDisplayData, Serializable { + /** + * Returns an object that represents at a high level the destination being written to. May not + * return null. + */ + public abstract DestinationT getDestination(UserT element); + + /** + * Returns the default destination. This is used for collections that have no elements as the + * destination to write empty files to. + */ + public abstract DestinationT getDefaultDestination(); + + /** + * Returns the coder for {@link DestinationT}. If this is not overridden, then the coder + * registry will be use to find a suitable coder. This must be a deterministic coder, as {@link + * DestinationT} will be used as a key type in a {@link + * org.apache.beam.sdk.transforms.GroupByKey}. + */ + @Nullable + public Coder getDestinationCoder() { + return null; + } + + /** Converts a destination into a {@link FilenamePolicy}. May not return null. */ + public abstract FilenamePolicy getFilenamePolicy(DestinationT destination); + + /** Populates the display data. */ + @Override + public void populateDisplayData(DisplayData.Builder builder) {} + + // Gets the destination coder. If the user does not provide one, try to find one in the coder + // registry. If no coder can be found, throws CannotProvideCoderException. + final Coder getDestinationCoderWithDefault(CoderRegistry registry) + throws CannotProvideCoderException { + Coder destinationCoder = getDestinationCoder(); + if (destinationCoder != null) { + return destinationCoder; + } + // If dynamicDestinations doesn't provide a coder, try to find it in the coder registry. + // We must first use reflection to figure out what the type parameter is. + TypeDescriptor superDescriptor = + TypeDescriptor.of(getClass()).getSupertype(DynamicDestinations.class); + if (!superDescriptor.getRawType().equals(DynamicDestinations.class)) { + throw new AssertionError( + "Couldn't find the DynamicDestinations superclass of " + this.getClass()); + } + TypeVariable typeVariable = superDescriptor.getTypeParameter("DestinationT"); + @SuppressWarnings("unchecked") + TypeDescriptor descriptor = + (TypeDescriptor) superDescriptor.resolveType(typeVariable); + return registry.getCoder(descriptor); + } + } + + /** A naming policy for output files. */ + @Experimental(Kind.FILESYSTEM) public abstract static class FilenamePolicy implements Serializable { /** * Context used for generating a name based on shard number, and num shards. @@ -287,29 +355,28 @@ public int getNumShards() { /** * When a sink has requested windowed or triggered output, this method will be invoked to return * the file {@link ResourceId resource} to be created given the base output directory and a - * (possibly empty) extension from {@link FileBasedSink} configuration - * (e.g., {@link CompressionType}). + * {@link OutputFileHints} containing information about the file, including a suggested + * extension (e.g. coming from {@link CompressionType}). * - *

    The {@link WindowedContext} object gives access to the window and pane, - * as well as sharding information. The policy must return unique and consistent filenames - * for different windows and panes. + *

The {@link WindowedContext} object gives access to the window and pane, as well as + * sharding information. The policy must return unique and consistent filenames for different + * windows and panes. */ @Experimental(Kind.FILESYSTEM) - public abstract ResourceId windowedFilename( - ResourceId outputDirectory, WindowedContext c, String extension); + public abstract ResourceId windowedFilename(WindowedContext c, OutputFileHints outputFileHints); /** * When a sink has not requested windowed or triggered output, this method will be invoked to * return the file {@link ResourceId resource} to be created given the base output directory and - * a (possibly empty) extension applied by additional {@link FileBasedSink} configuration - * (e.g., {@link CompressionType}). + * a {@link OutputFileHints} containing information about the file, including a suggested extension (e.g. + * coming from {@link CompressionType}). * *

    The {@link Context} object only provides sharding information, which is used by the policy * to generate unique and consistent filenames. */ @Experimental(Kind.FILESYSTEM) - @Nullable public abstract ResourceId unwindowedFilename( - ResourceId outputDirectory, Context c, String extension); + @Nullable + public abstract ResourceId unwindowedFilename(Context c, OutputFileHints outputFileHints); /** * Populates the display data. @@ -318,19 +385,8 @@ public void populateDisplayData(DisplayData.Builder builder) { } } - /** The policy used to generate names of files to be produced. */ - private final FilenamePolicy filenamePolicy; /** The directory to which files will be written. */ - private final ValueProvider baseOutputDirectoryProvider; - - /** - * Construct a {@link FileBasedSink} with the given filename policy, producing uncompressed files. - */ - @Experimental(Kind.FILESYSTEM) - public FileBasedSink( - ValueProvider baseOutputDirectoryProvider, FilenamePolicy filenamePolicy) { - this(baseOutputDirectoryProvider, filenamePolicy, CompressionType.UNCOMPRESSED); - } + private final ValueProvider tempDirectoryProvider; private static class ExtractDirectory implements SerializableFunction { @Override @@ -340,95 +396,91 @@ public ResourceId apply(ResourceId input) { } /** - * Construct a {@link FileBasedSink} with the given filename policy and output channel type. + * Construct a {@link FileBasedSink} with the given temp directory, producing uncompressed files. */ @Experimental(Kind.FILESYSTEM) public FileBasedSink( - ValueProvider baseOutputDirectoryProvider, - FilenamePolicy filenamePolicy, + ValueProvider tempDirectoryProvider, + DynamicDestinations dynamicDestinations) { + this(tempDirectoryProvider, dynamicDestinations, CompressionType.UNCOMPRESSED); + } + + /** Construct a {@link FileBasedSink} with the given temp directory and output channel type. */ + @Experimental(Kind.FILESYSTEM) + public FileBasedSink( + ValueProvider tempDirectoryProvider, + DynamicDestinations dynamicDestinations, WritableByteChannelFactory writableByteChannelFactory) { - this.baseOutputDirectoryProvider = - NestedValueProvider.of(baseOutputDirectoryProvider, new ExtractDirectory()); - this.filenamePolicy = filenamePolicy; + this.tempDirectoryProvider = + NestedValueProvider.of(tempDirectoryProvider, new ExtractDirectory()); + this.dynamicDestinations = checkNotNull(dynamicDestinations); this.writableByteChannelFactory = writableByteChannelFactory; } - /** - * Returns the base directory inside which files will be written according to the configured - * {@link FilenamePolicy}. - */ - @Experimental(Kind.FILESYSTEM) - public ValueProvider getBaseOutputDirectoryProvider() { - return baseOutputDirectoryProvider; + /** Return the {@link DynamicDestinations} used. */ + @SuppressWarnings("unchecked") + public DynamicDestinations getDynamicDestinations() { + return (DynamicDestinations) dynamicDestinations; } /** - * Returns the policy by which files will be named inside of the base output directory. Note that - * the {@link FilenamePolicy} may itself specify one or more inner directories before each output - * file, say when writing windowed outputs in a {@code output/YYYY/MM/DD/file.txt} format. + * Returns the directory inside which temprary files will be written according to the configured + * {@link FilenamePolicy}. 
*/ @Experimental(Kind.FILESYSTEM) - public final FilenamePolicy getFilenamePolicy() { - return filenamePolicy; + public ValueProvider getTempDirectoryProvider() { + return tempDirectoryProvider; } public void validate(PipelineOptions options) {} - /** - * Return a subclass of {@link WriteOperation} that will manage the write - * to the sink. - */ - public abstract WriteOperation createWriteOperation(); + /** Return a subclass of {@link WriteOperation} that will manage the write to the sink. */ + public abstract WriteOperation createWriteOperation(); public void populateDisplayData(DisplayData.Builder builder) { - getFilenamePolicy().populateDisplayData(builder); + getDynamicDestinations().populateDisplayData(builder); } /** * Abstract operation that manages the process of writing to {@link FileBasedSink}. * - *

    The primary responsibilities of the WriteOperation is the management of output - * files. During a write, {@link Writer}s write bundles to temporary file - * locations. After the bundles have been written, + *

The primary responsibility of the WriteOperation is the management of output files. During + * a write, {@link Writer}s write bundles to temporary file locations. After the bundles have been + * written, + * *

      - *
    1. {@link WriteOperation#finalize} is given a list of the temporary - * files containing the output bundles. - *
    2. During finalize, these temporary files are copied to final output locations and named - * according to a file naming template. - *
    3. Finally, any temporary files that were created during the write are removed. + *
1. {@link WriteOperation#finalize} is given a list of the temporary files containing the + * output bundles. + *
2. During finalize, these temporary files are copied to final output locations and named + * according to a file naming template. + *
3. Finally, any temporary files that were created during the write are removed. *
    * - *
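Sinks that need extra finalization work can override finalize while keeping the default copy-and-clean-up described above; a sketch (not part of the patch), assuming a WriteOperation subclass whose destination type is Void:

    @Override
    public void finalize(Iterable<FileResult<Void>> writerResults) throws Exception {
      super.finalize(writerResults);  // copies temp files to their final names, then removes temps
      // Any extra bookkeeping added here must stay idempotent, because finalize may be retried.
    }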

    Subclass implementations of WriteOperation must implement - * {@link WriteOperation#createWriter} to return a concrete - * FileBasedSinkWriter. + *

    Subclass implementations of WriteOperation must implement {@link + * WriteOperation#createWriter} to return a concrete FileBasedSinkWriter. * - *

    Temporary and Output File Naming:

    During the write, bundles are written to temporary - * files using the tempDirectory that can be provided via the constructor of - * WriteOperation. These temporary files will be named - * {@code {tempDirectory}/{bundleId}}, where bundleId is the unique id of the bundle. - * For example, if tempDirectory is "gs://my-bucket/my_temp_output", the output for a - * bundle with bundle id 15723 will be "gs://my-bucket/my_temp_output/15723". + *

    Temporary and Output File Naming:

    * - *

    Final output files are written to baseOutputFilename with the format - * {@code {baseOutputFilename}-0000i-of-0000n.{extension}} where n is the total number of bundles - * written and extension is the file extension. Both baseOutputFilename and extension are required - * constructor arguments. + *

    During the write, bundles are written to temporary files using the tempDirectory that can be + * provided via the constructor of WriteOperation. These temporary files will be named {@code + * {tempDirectory}/{bundleId}}, where bundleId is the unique id of the bundle. For example, if + * tempDirectory is "gs://my-bucket/my_temp_output", the output for a bundle with bundle id 15723 + * will be "gs://my-bucket/my_temp_output/15723". * - *

    Subclass implementations can change the file naming template by supplying a value for - * fileNamingTemplate. + *

    Final output files are written to the location specified by the {@link FilenamePolicy}. If + * no filename policy is specified, then the {@link DefaultFilenamePolicy} will be used. The + * directory that the files are written to is determined by the {@link FilenamePolicy} instance. * *

    Note that in the case of permanent failure of a bundle's write, no clean up of temporary * files will occur. * *

    If there are no elements in the PCollection being written, no output will be generated. * - * @param the type of values written to the sink. + * @param the type of values written to the sink. */ - public abstract static class WriteOperation implements Serializable { - /** - * The Sink that this WriteOperation will write to. - */ - protected final FileBasedSink sink; + public abstract static class WriteOperation implements Serializable { + /** The Sink that this WriteOperation will write to. */ + protected final FileBasedSink sink; /** Directory for temporary output files. */ protected final ValueProvider tempDirectory; @@ -445,17 +497,19 @@ protected static ResourceId buildTemporaryFilename(ResourceId tempDirectory, Str } /** - * Constructs a WriteOperation using the default strategy for generating a temporary - * directory from the base output filename. + * Constructs a WriteOperation using the default strategy for generating a temporary directory + * from the base output filename. * - *

    Default is a uniquely named sibling of baseOutputFilename, e.g. if baseOutputFilename is - * /path/to/foo, the temporary directory will be /path/to/temp-beam-foo-$date. + *

    Default is a uniquely named subdirectory of the provided tempDirectory, e.g. if + * tempDirectory is /path/to/foo/, the temporary directory will be + * /path/to/foo/temp-beam-foo-$date. * * @param sink the FileBasedSink that will be used to configure this write operation. */ - public WriteOperation(FileBasedSink sink) { - this(sink, NestedValueProvider.of( - sink.getBaseOutputDirectoryProvider(), new TemporaryDirectoryBuilder())); + public WriteOperation(FileBasedSink sink) { + this( + sink, + NestedValueProvider.of(sink.getTempDirectoryProvider(), new TemporaryDirectoryBuilder())); } private static class TemporaryDirectoryBuilder @@ -471,10 +525,12 @@ private static class TemporaryDirectoryBuilder private final Long tempId = TEMP_COUNT.getAndIncrement(); @Override - public ResourceId apply(ResourceId baseOutputDirectory) { + public ResourceId apply(ResourceId tempDirectory) { // Temp directory has a timestamp and a unique ID String tempDirName = String.format(".temp-beam-%s-%s", timestamp, tempId); - return baseOutputDirectory.resolve(tempDirName, StandardResolveOptions.RESOLVE_DIRECTORY); + return tempDirectory + .getCurrentDirectory() + .resolve(tempDirName, StandardResolveOptions.RESOLVE_DIRECTORY); } } @@ -485,22 +541,22 @@ public ResourceId apply(ResourceId baseOutputDirectory) { * @param tempDirectory the base directory to be used for temporary output files. */ @Experimental(Kind.FILESYSTEM) - public WriteOperation(FileBasedSink sink, ResourceId tempDirectory) { + public WriteOperation(FileBasedSink sink, ResourceId tempDirectory) { this(sink, StaticValueProvider.of(tempDirectory)); } private WriteOperation( - FileBasedSink sink, ValueProvider tempDirectory) { + FileBasedSink sink, ValueProvider tempDirectory) { this.sink = sink; this.tempDirectory = tempDirectory; this.windowedWrites = false; } /** - * Clients must implement to return a subclass of {@link Writer}. This - * method must not mutate the state of the object. + * Clients must implement to return a subclass of {@link Writer}. This method must not mutate + * the state of the object. */ - public abstract Writer createWriter() throws Exception; + public abstract Writer createWriter() throws Exception; /** * Indicates that the operation will be performing windowed writes. @@ -514,8 +570,8 @@ public void setWindowedWrites(boolean windowedWrites) { * removing temporary files. * *

    Finalization may be overridden by subclass implementations to perform customized - * finalization (e.g., initiating some operation on output bundles, merging them, etc.). - * {@code writerResults} contains the filenames of written bundles. + * finalization (e.g., initiating some operation on output bundles, merging them, etc.). {@code + * writerResults} contains the filenames of written bundles. * *

    If subclasses override this method, they must guarantee that its implementation is * idempotent, as it may be executed multiple times in the case of failure or for redundancy. It @@ -523,7 +579,7 @@ public void setWindowedWrites(boolean windowedWrites) { * * @param writerResults the results of writes (FileResult). */ - public void finalize(Iterable writerResults) throws Exception { + public void finalize(Iterable> writerResults) throws Exception { // Collect names of temporary files and rename them. Map outputFilenames = buildOutputFilenames(writerResults); copyToOutputFiles(outputFilenames); @@ -542,17 +598,14 @@ public void finalize(Iterable writerResults) throws Exception { @Experimental(Kind.FILESYSTEM) protected final Map buildOutputFilenames( - Iterable writerResults) { + Iterable> writerResults) { int numShards = Iterables.size(writerResults); Map outputFilenames = new HashMap<>(); - FilenamePolicy policy = getSink().getFilenamePolicy(); - ResourceId baseOutputDir = getSink().getBaseOutputDirectoryProvider().get(); - // Either all results have a shard number set (if the sink is configured with a fixed // number of shards), or they all don't (otherwise). Boolean isShardNumberSetEverywhere = null; - for (FileResult result : writerResults) { + for (FileResult result : writerResults) { boolean isShardNumberSetHere = (result.getShard() != UNKNOWN_SHARDNUM); if (isShardNumberSetEverywhere == null) { isShardNumberSetEverywhere = isShardNumberSetHere; @@ -568,7 +621,7 @@ protected final Map buildOutputFilenames( isShardNumberSetEverywhere = true; } - List resultsWithShardNumbers = Lists.newArrayList(); + List> resultsWithShardNumbers = Lists.newArrayList(); if (isShardNumberSetEverywhere) { resultsWithShardNumbers = Lists.newArrayList(writerResults); } else { @@ -577,29 +630,32 @@ protected final Map buildOutputFilenames( // case of triggers, the list of FileResult objects in the Finalize iterable is not // deterministic, and might change over retries. This breaks the assumption below that // sorting the FileResult objects provides idempotency. - List sortedByTempFilename = + List> sortedByTempFilename = Ordering.from( - new Comparator() { - @Override - public int compare(FileResult first, FileResult second) { - String firstFilename = first.getTempFilename().toString(); - String secondFilename = second.getTempFilename().toString(); - return firstFilename.compareTo(secondFilename); - } - }) + new Comparator>() { + @Override + public int compare( + FileResult first, FileResult second) { + String firstFilename = first.getTempFilename().toString(); + String secondFilename = second.getTempFilename().toString(); + return firstFilename.compareTo(secondFilename); + } + }) .sortedCopy(writerResults); for (int i = 0; i < sortedByTempFilename.size(); i++) { resultsWithShardNumbers.add(sortedByTempFilename.get(i).withShard(i)); } } - for (FileResult result : resultsWithShardNumbers) { + for (FileResult result : resultsWithShardNumbers) { checkArgument( result.getShard() != UNKNOWN_SHARDNUM, "Should have set shard number on %s", result); outputFilenames.put( result.getTempFilename(), result.getDestinationFile( - policy, baseOutputDir, numShards, getSink().getExtension())); + getSink().getDynamicDestinations(), + numShards, + getSink().getWritableByteChannelFactory())); } int numDistinctShards = new HashSet<>(outputFilenames.values()).size(); @@ -615,18 +671,18 @@ public int compare(FileResult first, FileResult second) { * *

    Can be called from subclasses that override {@link WriteOperation#finalize}. * - *

    Files will be named according to the file naming template. The order of the output files - * will be the same as the sorted order of the input filenames. In other words, if the input - * filenames are ["C", "A", "B"], baseOutputFilename is "file", the extension is ".txt", and - * the fileNamingTemplate is "-SSS-of-NNN", the contents of A will be copied to - * file-000-of-003.txt, the contents of B will be copied to file-001-of-003.txt, etc. + *

    Files will be named according to the {@link FilenamePolicy}. The order of the output files + * will be the same as the sorted order of the input filenames. In other words (when using + * {@link DefaultFilenamePolicy}), if the input filenames are ["C", "A", "B"], baseFilename (int + * the policy) is "dir/file", the extension is ".txt", and the fileNamingTemplate is + * "-SSS-of-NNN", the contents of A will be copied to dir/file-000-of-003.txt, the contents of B + * will be copied to dir/file-001-of-003.txt, etc. * * @param filenames the filenames of temporary files. */ @VisibleForTesting @Experimental(Kind.FILESYSTEM) - final void copyToOutputFiles(Map filenames) - throws IOException { + final void copyToOutputFiles(Map filenames) throws IOException { int numFiles = filenames.size(); if (numFiles > 0) { LOG.debug("Copying {} files.", numFiles); @@ -698,10 +754,8 @@ final void removeTemporaryFiles( } } - /** - * Returns the FileBasedSink for this write operation. - */ - public FileBasedSink getSink() { + /** Returns the FileBasedSink for this write operation. */ + public FileBasedSink getSink() { return sink; } @@ -719,33 +773,28 @@ public String toString() { } } - /** Returns the extension that will be written to the produced files. */ - protected final String getExtension() { - String extension = MoreObjects.firstNonNull(writableByteChannelFactory.getFilenameSuffix(), ""); - if (!extension.isEmpty() && !extension.startsWith(".")) { - extension = "." + extension; - } - return extension; + /** Returns the {@link WritableByteChannelFactory} used. */ + protected final WritableByteChannelFactory getWritableByteChannelFactory() { + return writableByteChannelFactory; } /** - * Abstract writer that writes a bundle to a {@link FileBasedSink}. Subclass - * implementations provide a method that can write a single value to a - * {@link WritableByteChannel}. + * Abstract writer that writes a bundle to a {@link FileBasedSink}. Subclass implementations + * provide a method that can write a single value to a {@link WritableByteChannel}. * *

    Subclass implementations may also override methods that write headers and footers before and * after the values in a bundle, respectively, as well as provide a MIME type for the output * channel. * - *
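A minimal Writer sketch (not part of the patch). The prepareWrite hook and the <Void, String> type arguments are assumptions, since neither is visible in this rendering of the diff; imports are omitted.

    private static class LineWriter extends Writer<Void, String> {
      private OutputStream out;

      LineWriter(WriteOperation<Void, String> writeOperation) {
        super(writeOperation, MimeTypes.TEXT);
      }

      @Override
      protected void prepareWrite(WritableByteChannel channel) throws Exception {
        out = Channels.newOutputStream(channel);  // assumed hook, called once the temp file is open
      }

      @Override
      public void write(String value) throws Exception {
        out.write((value + "\n").getBytes(StandardCharsets.UTF_8));
      }
    }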

    Multiple {@link Writer} instances may be created on the same worker, and therefore - * any access to static members or methods should be thread safe. + *

    Multiple {@link Writer} instances may be created on the same worker, and therefore any + * access to static members or methods should be thread safe. * - * @param the type of values to write. + * @param the type of values to write. */ - public abstract static class Writer { + public abstract static class Writer { private static final Logger LOG = LoggerFactory.getLogger(Writer.class); - private final WriteOperation writeOperation; + private final WriteOperation writeOperation; /** Unique id for this output bundle. */ private String id; @@ -753,6 +802,7 @@ public abstract static class Writer { private BoundedWindow window; private PaneInfo paneInfo; private int shard = -1; + private DestinationT destination; /** The output file for this bundle. May be null if opening failed. */ private @Nullable ResourceId outputFile; @@ -772,10 +822,8 @@ public abstract static class Writer { */ private final String mimeType; - /** - * Construct a new {@link Writer} that will produce files of the given MIME type. - */ - public Writer(WriteOperation writeOperation, String mimeType) { + /** Construct a new {@link Writer} that will produce files of the given MIME type. */ + public Writer(WriteOperation writeOperation, String mimeType) { checkNotNull(writeOperation); this.writeOperation = writeOperation; this.mimeType = mimeType; @@ -818,28 +866,29 @@ protected void finishWrite() throws Exception {} * id populated for the case of static sharding. In cases where the runner is dynamically * picking sharding, shard might be set to -1. */ - public final void openWindowed(String uId, BoundedWindow window, PaneInfo paneInfo, int shard) + public final void openWindowed( + String uId, BoundedWindow window, PaneInfo paneInfo, int shard, DestinationT destination) throws Exception { if (!getWriteOperation().windowedWrites) { throw new IllegalStateException("openWindowed called a non-windowed sink."); } - open(uId, window, paneInfo, shard); + open(uId, window, paneInfo, shard, destination); } /** * Called for each value in the bundle. */ - public abstract void write(T value) throws Exception; + public abstract void write(OutputT value) throws Exception; /** - * Similar to {@link #openWindowed} however for the case where unwindowed writes were - * requested. + * Similar to {@link #openWindowed} however for the case where unwindowed writes were requested. */ - public final void openUnwindowed(String uId, int shard) throws Exception { + public final void openUnwindowed(String uId, int shard, DestinationT destination) + throws Exception { if (getWriteOperation().windowedWrites) { throw new IllegalStateException("openUnwindowed called a windowed sink."); } - open(uId, null, null, shard); + open(uId, null, null, shard, destination); } // Helper function to close a channel, on exception cases. 
@@ -855,14 +904,18 @@ private static void closeChannelAndThrow( } } - private void open(String uId, - @Nullable BoundedWindow window, - @Nullable PaneInfo paneInfo, - int shard) throws Exception { + private void open( + String uId, + @Nullable BoundedWindow window, + @Nullable PaneInfo paneInfo, + int shard, + DestinationT destination) + throws Exception { this.id = uId; this.window = window; this.paneInfo = paneInfo; this.shard = shard; + this.destination = destination; ResourceId tempDirectory = getWriteOperation().tempDirectory.get(); outputFile = tempDirectory.resolve(id, StandardResolveOptions.RESOLVE_FILE); verifyNotNull( @@ -908,7 +961,7 @@ public final void cleanup() throws Exception { } /** Closes the channel and returns the bundle result. */ - public final FileResult close() throws Exception { + public final FileResult close() throws Exception { checkState(outputFile != null, "FileResult.close cannot be called with a null outputFile"); LOG.debug("Writing footer to {}.", outputFile); @@ -938,35 +991,41 @@ public final FileResult close() throws Exception { throw new IOException(String.format("Failed closing channel to %s", outputFile), e); } - FileResult result = new FileResult(outputFile, shard, window, paneInfo); + FileResult result = + new FileResult<>(outputFile, shard, window, paneInfo, destination); LOG.debug("Result for bundle {}: {}", this.id, outputFile); return result; } - /** - * Return the WriteOperation that this Writer belongs to. - */ - public WriteOperation getWriteOperation() { + /** Return the WriteOperation that this Writer belongs to. */ + public WriteOperation getWriteOperation() { return writeOperation; } } /** - * Result of a single bundle write. Contains the filename produced by the bundle, and if known - * the final output filename. + * Result of a single bundle write. Contains the filename produced by the bundle, and if known the + * final output filename. 
*/ - public static final class FileResult { + public static final class FileResult { private final ResourceId tempFilename; private final int shard; private final BoundedWindow window; private final PaneInfo paneInfo; + private final DestinationT destination; @Experimental(Kind.FILESYSTEM) - public FileResult(ResourceId tempFilename, int shard, BoundedWindow window, PaneInfo paneInfo) { + public FileResult( + ResourceId tempFilename, + int shard, + BoundedWindow window, + PaneInfo paneInfo, + DestinationT destination) { this.tempFilename = tempFilename; this.shard = shard; this.window = window; this.paneInfo = paneInfo; + this.destination = destination; } @Experimental(Kind.FILESYSTEM) @@ -978,8 +1037,8 @@ public int getShard() { return shard; } - public FileResult withShard(int shard) { - return new FileResult(tempFilename, shard, window, paneInfo); + public FileResult withShard(int shard) { + return new FileResult<>(tempFilename, shard, window, paneInfo, destination); } public BoundedWindow getWindow() { @@ -990,17 +1049,24 @@ public PaneInfo getPaneInfo() { return paneInfo; } + public DestinationT getDestination() { + return destination; + } + @Experimental(Kind.FILESYSTEM) - public ResourceId getDestinationFile(FilenamePolicy policy, ResourceId outputDirectory, - int numShards, String extension) { + public ResourceId getDestinationFile( + DynamicDestinations dynamicDestinations, + int numShards, + OutputFileHints outputFileHints) { checkArgument(getShard() != UNKNOWN_SHARDNUM); checkArgument(numShards > 0); + FilenamePolicy policy = dynamicDestinations.getFilenamePolicy(destination); if (getWindow() != null) { - return policy.windowedFilename(outputDirectory, new WindowedContext( - getWindow(), getPaneInfo(), getShard(), numShards), extension); + return policy.windowedFilename( + new WindowedContext(getWindow(), getPaneInfo(), getShard(), numShards), + outputFileHints); } else { - return policy.unwindowedFilename(outputDirectory, new Context(getShard(), numShards), - extension); + return policy.unwindowedFilename(new Context(getShard(), numShards), outputFileHints); } } @@ -1014,22 +1080,24 @@ public String toString() { } } - /** - * A coder for {@link FileResult} objects. - */ - public static final class FileResultCoder extends StructuredCoder { + /** A coder for {@link FileResult} objects. 
*/ + public static final class FileResultCoder + extends StructuredCoder> { private static final Coder FILENAME_CODER = StringUtf8Coder.of(); private static final Coder SHARD_CODER = VarIntCoder.of(); private static final Coder PANE_INFO_CODER = NullableCoder.of(PaneInfoCoder.INSTANCE); - private final Coder windowCoder; + private final Coder destinationCoder; - protected FileResultCoder(Coder windowCoder) { + protected FileResultCoder( + Coder windowCoder, Coder destinationCoder) { this.windowCoder = NullableCoder.of(windowCoder); + this.destinationCoder = destinationCoder; } - public static FileResultCoder of(Coder windowCoder) { - return new FileResultCoder(windowCoder); + public static FileResultCoder of( + Coder windowCoder, Coder destinationCoder) { + return new FileResultCoder<>(windowCoder, destinationCoder); } @Override @@ -1038,8 +1106,7 @@ public List> getCoderArguments() { } @Override - public void encode(FileResult value, OutputStream outStream) - throws IOException { + public void encode(FileResult value, OutputStream outStream) throws IOException { if (value == null) { throw new CoderException("cannot encode a null value"); } @@ -1047,17 +1114,22 @@ public void encode(FileResult value, OutputStream outStream) windowCoder.encode(value.getWindow(), outStream); PANE_INFO_CODER.encode(value.getPaneInfo(), outStream); SHARD_CODER.encode(value.getShard(), outStream); + destinationCoder.encode(value.getDestination(), outStream); } @Override - public FileResult decode(InputStream inStream) - throws IOException { + public FileResult decode(InputStream inStream) throws IOException { String tempFilename = FILENAME_CODER.decode(inStream); BoundedWindow window = windowCoder.decode(inStream); PaneInfo paneInfo = PANE_INFO_CODER.decode(inStream); int shard = SHARD_CODER.decode(inStream); - return new FileResult(FileSystems.matchNewResource(tempFilename, false /* isDirectory */), - shard, window, paneInfo); + DestinationT destination = destinationCoder.decode(inStream); + return new FileResult<>( + FileSystems.matchNewResource(tempFilename, false /* isDirectory */), + shard, + window, + paneInfo, + destination); } @Override @@ -1066,25 +1138,15 @@ public void verifyDeterministic() throws NonDeterministicException { windowCoder.verifyDeterministic(); PANE_INFO_CODER.verifyDeterministic(); SHARD_CODER.verifyDeterministic(); + destinationCoder.verifyDeterministic(); } } /** - * Implementations create instances of {@link WritableByteChannel} used by {@link FileBasedSink} - * and related classes to allow decorating, or otherwise transforming, the raw data that - * would normally be written directly to the {@link WritableByteChannel} passed into - * {@link WritableByteChannelFactory#create(WritableByteChannel)}. - * - *

    Subclasses should override {@link #toString()} with something meaningful, as it is used when - * building {@link DisplayData}. + * Provides hints about how to generate output files, such as a suggested filename suffix (e.g. + * based on the compression type), and the file MIME type. */ - public interface WritableByteChannelFactory extends Serializable { - /** - * @param channel the {@link WritableByteChannel} to wrap - * @return the {@link WritableByteChannel} to be used during output - */ - WritableByteChannel create(WritableByteChannel channel) throws IOException; - + public interface OutputFileHints extends Serializable { /** * Returns the MIME type that should be used for the files that will hold the output data. May * return {@code null} if this {@code WritableByteChannelFactory} does not meaningfully change @@ -1101,6 +1163,23 @@ public interface WritableByteChannelFactory extends Serializable { * @return an optional filename suffix, eg, ".gz" is returned by {@link CompressionType#GZIP} */ @Nullable - String getFilenameSuffix(); + String getSuggestedFilenameSuffix(); + } + + /** + * Implementations create instances of {@link WritableByteChannel} used by {@link FileBasedSink} + * and related classes to allow decorating, or otherwise transforming, the raw data that + * would normally be written directly to the {@link WritableByteChannel} passed into {@link + * WritableByteChannelFactory#create(WritableByteChannel)}. + * + *

    Subclasses should override {@link #toString()} with something meaningful, as it is used when + * building {@link DisplayData}. + */ + public interface WritableByteChannelFactory extends OutputFileHints { + /** + * @param channel the {@link WritableByteChannel} to wrap + * @return the {@link WritableByteChannel} to be used during output + */ + WritableByteChannel create(WritableByteChannel channel) throws IOException; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java index e28807507819a..6e7b243b6284d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.util.MimeTypes; import org.apache.beam.sdk.values.PBegin; @@ -355,12 +356,11 @@ public Write withCompressionType(CompressionType compressionType) { public PDone expand(PCollection input) { checkState(getOutputPrefix() != null, "need to set the output prefix of a TFRecordIO.Write transform"); - WriteFiles write = WriteFiles.to( + WriteFiles write = + WriteFiles.to( new TFRecordSink( - getOutputPrefix(), - getShardTemplate(), - getFilenameSuffix(), - getCompressionType())); + getOutputPrefix(), getShardTemplate(), getFilenameSuffix(), getCompressionType()), + SerializableFunctions.identity()); if (getNumShards() > 0) { write = write.withNumShards(getNumShards()); } @@ -546,20 +546,20 @@ protected boolean readNextRecord() throws IOException { } } - /** - * A {@link FileBasedSink} for TFRecord files. Produces TFRecord files. - */ + /** A {@link FileBasedSink} for TFRecord files. Produces TFRecord files. */ @VisibleForTesting - static class TFRecordSink extends FileBasedSink { + static class TFRecordSink extends FileBasedSink { @VisibleForTesting - TFRecordSink(ValueProvider outputPrefix, + TFRecordSink( + ValueProvider outputPrefix, @Nullable String shardTemplate, @Nullable String suffix, TFRecordIO.CompressionType compressionType) { super( outputPrefix, - DefaultFilenamePolicy.constructUsingStandardParameters( - outputPrefix, shardTemplate, suffix, false), + DynamicFileDestinations.constant( + DefaultFilenamePolicy.fromStandardParameters( + outputPrefix, shardTemplate, suffix, false)), writableByteChannelFactory(compressionType)); } @@ -571,7 +571,7 @@ public ResourceId apply(ResourceId input) { } @Override - public WriteOperation createWriteOperation() { + public WriteOperation createWriteOperation() { return new TFRecordWriteOperation(this); } @@ -590,30 +590,24 @@ private static WritableByteChannelFactory writableByteChannelFactory( return CompressionType.UNCOMPRESSED; } - /** - * A {@link WriteOperation - * WriteOperation} for TFRecord files. - */ - private static class TFRecordWriteOperation extends WriteOperation { + /** A {@link WriteOperation WriteOperation} for TFRecord files. 
*/ + private static class TFRecordWriteOperation extends WriteOperation { private TFRecordWriteOperation(TFRecordSink sink) { super(sink); } @Override - public Writer createWriter() throws Exception { + public Writer createWriter() throws Exception { return new TFRecordWriter(this); } } - /** - * A {@link Writer Writer} - * for TFRecord files. - */ - private static class TFRecordWriter extends Writer { + /** A {@link Writer Writer} for TFRecord files. */ + private static class TFRecordWriter extends Writer { private WritableByteChannel outChannel; private TFRecordCodec codec; - private TFRecordWriter(WriteOperation writeOperation) { + private TFRecordWriter(WriteOperation writeOperation) { super(writeOperation, MimeTypes.BINARY); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java index f1eb7c0bde421..524158968237a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java @@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.FileBasedSink.WritableByteChannelFactory; import org.apache.beam.sdk.io.Read.Bounded; @@ -37,6 +40,7 @@ import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -65,19 +69,8 @@ *

    To write a {@link PCollection} to one or more text files, use {@code TextIO.write()}, using * {@link TextIO.Write#to(String)} to specify the output prefix of the files to write. * - *

    By default, all input is put into the global window before writing. If per-window writes are - * desired - for example, when using a streaming runner - - * {@link TextIO.Write#withWindowedWrites()} will cause windowing and triggering to be - * preserved. When producing windowed writes, the number of output shards must be set explicitly - * using {@link TextIO.Write#withNumShards(int)}; some runners may set this for you to a - * runner-chosen value, so you may need not set it yourself. A {@link FilenamePolicy} can also be - * set in case you need better control over naming files created by unique windows. - * {@link DefaultFilenamePolicy} policy for producing unique filenames might not be appropriate - * for your use case. - * - *

    Any existing files with the same names as generated output files will be overwritten. - * *

    For example: + * *

    {@code
      * // A simple Write to a local file (only runs locally):
* PCollection<String> lines = ...;
    @@ -85,10 +78,49 @@
      *
      * // Same as above, only with Gzip compression:
* PCollection<String> lines = ...;
    - * lines.apply(TextIO.write().to("/path/to/file.txt"));
    + * lines.apply(TextIO.write().to("/path/to/file.txt"))
      *      .withSuffix(".txt")
      *      .withWritableByteChannelFactory(FileBasedSink.CompressionType.GZIP));
      * }
    + * + *

By default, all input is put into the global window before writing. If per-window writes are + * desired - for example, when using a streaming runner - {@link TextIO.Write#withWindowedWrites()} + * will cause windowing and triggering to be preserved. When producing windowed writes with a + * streaming runner that supports triggers, the number of output shards must be set explicitly using + * {@link TextIO.Write#withNumShards(int)}; some runners may set this for you to a runner-chosen + * value, so you may not need to set it yourself. If setting an explicit template using {@link + * TextIO.Write#withShardNameTemplate(String)}, make sure that the template contains placeholders + * for the window and the pane; W is expanded into the window text, and P into the pane; the default + * template will include both the window and the pane in the filename. + * + *
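For instance (a sketch, not part of the patch; the bucket path and the exact shard-name template are assumptions built from the W/P placeholders described above):

    PCollection<String> lines = ...;  // a windowed, triggered collection
    lines.apply(TextIO.write()
        .to("gs://my-bucket/output/events")
        .withWindowedWrites()
        .withNumShards(3)
        .withShardNameTemplate("W-P-SSS-of-NNN"));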

    If you want better control over how filenames are generated than the default policy allows, a + * custom {@link FilenamePolicy} can also be set using {@link TextIO.Write#to(FilenamePolicy)}. + * + *
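For example, a custom policy such as the PerWindowPolicy sketched earlier can be wired in directly (a sketch, not part of the patch; it assumes baseDir is a ResourceId and that TextIO.Write exposes the same withTempDirectory option as the TypedWrite shown in this diff):

    lines.apply(TextIO.write()
        .to(new PerWindowPolicy(baseDir))
        .withTempDirectory(baseDir)   // to(FilenamePolicy) no longer implies a temp directory
        .withWindowedWrites()
        .withNumShards(3));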

    TextIO also supports dynamic, value-dependent file destinations. The most general form of this + * is done via {@link TextIO.Write#to(DynamicDestinations)}. A {@link DynamicDestinations} class + * allows you to convert any input value into a custom destination object, and map that destination + * object to a {@link FilenamePolicy}. This allows using different filename policies (or more + * commonly, differently-configured instances of the same policy) based on the input record. Often + * this is used in conjunction with {@link TextIO#writeCustomType(SerializableFunction)}, which + * allows your {@link DynamicDestinations} object to examine the input type and takes a format + * function to convert that type to a string for writing. + * + *
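A sketch of that pattern (not part of the patch): here the destination type is a plain country-code String, mapped to a differently-configured DefaultFilenamePolicy per destination. UserEvent, FormatEvent and tempDirectory are assumptions, the generic parameters are reconstructed because this rendering of the diff has dropped them, and the withBaseFilename argument mirrors the Javadoc example below.

    class PerCountryDestinations extends FileBasedSink.DynamicDestinations<UserEvent, String> {
      private final String baseDirectory = "gs://my-bucket/output";  // assumed

      @Override
      public String getDestination(UserEvent element) {
        return element.country();                  // route by a field of the record
      }

      @Override
      public String getDefaultDestination() {
        return "unknown";                          // destination used when the collection is empty
      }

      @Override
      public Coder<String> getDestinationCoder() {
        return StringUtf8Coder.of();               // deterministic, as the grouping step requires
      }

      @Override
      public FileBasedSink.FilenamePolicy getFilenamePolicy(String country) {
        return DefaultFilenamePolicy.fromParams(
            new Params().withBaseFilename(baseDirectory + "/" + country));
      }
    }

    events.apply(TextIO.writeCustomType(new FormatEvent())  // FormatEvent: UserEvent -> String
        .to(new PerCountryDestinations())
        .withTempDirectory(tempDirectory));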

    A convenience shortcut is provided for the case where the default naming policy is used, but + * different configurations of this policy are wanted based on the input record. Default naming + * policies can be configured using the {@link DefaultFilenamePolicy.Params} object. + * + *

    {@code
+ * PCollection<UserEvent> lines = ...;
+ * lines.apply(TextIO.writeCustomType(new FormatEvent())
+ *      .to(new SerializableFunction<UserEvent, Params>() {
+ *         public Params apply(UserEvent value) {
+ *           return new Params().withBaseFilename(baseDirectory + "/" + value.country());
+ *         }
+ *       },
+ *       new Params().withBaseFilename(baseDirectory + "/empty")));
    + * }
    + * + *

    Any existing files with the same names as generated output files will be overwritten. */ public class TextIO { /** @@ -105,11 +137,29 @@ public static Read read() { * line. */ public static Write write() { - return new AutoValue_TextIO_Write.Builder() + return new TextIO.Write(); + } + + /** + * A {@link PTransform} that writes a {@link PCollection} to a text file (or multiple text files + * matching a sharding pattern), with each element of the input collection encoded into its own + * line. + * + *

    This version allows you to apply {@link TextIO} writes to a PCollection of a custom type + * {@link T}, along with a format function that converts the input type {@link T} to the String + * that will be written to the file. The advantage of this is it allows a user-provided {@link + * DynamicDestinations} object, set via {@link Write#to(DynamicDestinations)} to examine the + * user's custom type when choosing a destination. + */ + public static TypedWrite writeCustomType(SerializableFunction formatFunction) { + return new AutoValue_TextIO_TypedWrite.Builder() .setFilenamePrefix(null) + .setTempDirectory(null) .setShardTemplate(null) .setFilenameSuffix(null) .setFilenamePolicy(null) + .setDynamicDestinations(null) + .setFormatFunction(formatFunction) .setWritableByteChannelFactory(FileBasedSink.CompressionType.UNCOMPRESSED) .setWindowedWrites(false) .setNumShards(0) @@ -223,18 +273,21 @@ protected Coder getDefaultOutputCoder() { } } - - ///////////////////////////////////////////////////////////////////////////// + // /////////////////////////////////////////////////////////////////////////// /** Implementation of {@link #write}. */ @AutoValue - public abstract static class Write extends PTransform, PDone> { + public abstract static class TypedWrite extends PTransform, PDone> { /** The prefix of each file written, combined with suffix and shardTemplate. */ @Nullable abstract ValueProvider getFilenamePrefix(); /** The suffix of each file written, combined with prefix and shardTemplate. */ @Nullable abstract String getFilenameSuffix(); + /** The base directory used for generating temporary files. */ + @Nullable + abstract ValueProvider getTempDirectory(); + /** An optional header to add to each file. */ @Nullable abstract String getHeader(); @@ -250,6 +303,13 @@ public abstract static class Write extends PTransform, PDone /** A policy for naming output files. */ @Nullable abstract FilenamePolicy getFilenamePolicy(); + /** Allows for value-dependent {@link DynamicDestinations} to be vended. */ + @Nullable + abstract DynamicDestinations getDynamicDestinations(); + + /** A function that converts T to a String, for writing to the file. */ + abstract SerializableFunction getFormatFunction(); + /** Whether to write windowed output files. 
*/ abstract boolean getWindowedWrites(); @@ -259,66 +319,68 @@ public abstract static class Write extends PTransform, PDone */ abstract WritableByteChannelFactory getWritableByteChannelFactory(); - abstract Builder toBuilder(); + abstract Builder toBuilder(); @AutoValue.Builder - abstract static class Builder { - abstract Builder setFilenamePrefix(ValueProvider filenamePrefix); - abstract Builder setShardTemplate(@Nullable String shardTemplate); - abstract Builder setFilenameSuffix(@Nullable String filenameSuffix); - abstract Builder setHeader(@Nullable String header); - abstract Builder setFooter(@Nullable String footer); - abstract Builder setFilenamePolicy(@Nullable FilenamePolicy filenamePolicy); - abstract Builder setNumShards(int numShards); - abstract Builder setWindowedWrites(boolean windowedWrites); - abstract Builder setWritableByteChannelFactory( + abstract static class Builder { + abstract Builder setFilenamePrefix(ValueProvider filenamePrefix); + + abstract Builder setTempDirectory(ValueProvider tempDirectory); + + abstract Builder setShardTemplate(@Nullable String shardTemplate); + + abstract Builder setFilenameSuffix(@Nullable String filenameSuffix); + + abstract Builder setHeader(@Nullable String header); + + abstract Builder setFooter(@Nullable String footer); + + abstract Builder setFilenamePolicy(@Nullable FilenamePolicy filenamePolicy); + + abstract Builder setDynamicDestinations( + @Nullable DynamicDestinations dynamicDestinations); + + abstract Builder setFormatFunction(SerializableFunction formatFunction); + + abstract Builder setNumShards(int numShards); + + abstract Builder setWindowedWrites(boolean windowedWrites); + + abstract Builder setWritableByteChannelFactory( WritableByteChannelFactory writableByteChannelFactory); - abstract Write build(); + abstract TypedWrite build(); } /** - * Writes to text files with the given prefix. The given {@code prefix} can reference any - * {@link FileSystem} on the classpath. - * - *

    The name of the output files will be determined by the {@link FilenamePolicy} used. + * Writes to text files with the given prefix. The given {@code prefix} can reference any {@link + * FileSystem} on the classpath. This prefix is used by the {@link DefaultFilenamePolicy} to + * generate filenames. * *

    By default, a {@link DefaultFilenamePolicy} will be used built using the specified prefix - * to define the base output directory and file prefix, a shard identifier (see - * {@link #withNumShards(int)}), and a common suffix (if supplied using - * {@link #withSuffix(String)}). + * to define the base output directory and file prefix, a shard identifier (see {@link + * #withNumShards(int)}), and a common suffix (if supplied using {@link #withSuffix(String)}). + * + *

    This default policy can be overridden using {@link #to(FilenamePolicy)}, in which case + * {@link #withShardNameTemplate(String)} and {@link #withSuffix(String)} should not be set. + * Custom filename policies do not automatically see this prefix - you should explicitly pass + * the prefix into your {@link FilenamePolicy} object if you need this. * - *

    This default policy can be overridden using {@link #withFilenamePolicy(FilenamePolicy)}, - * in which case {@link #withShardNameTemplate(String)} and {@link #withSuffix(String)} should - * not be set. + *

    If {@link #withTempDirectory} has not been called, this filename prefix will be used to + * infer a directory for temporary files. */ - public Write to(String filenamePrefix) { + public TypedWrite to(String filenamePrefix) { return to(FileBasedSink.convertToFileResourceIfPossible(filenamePrefix)); } - /** - * Writes to text files with prefix from the given resource. - * - *

    The name of the output files will be determined by the {@link FilenamePolicy} used. - * - *

    By default, a {@link DefaultFilenamePolicy} will be used built using the specified prefix - * to define the base output directory and file prefix, a shard identifier (see - * {@link #withNumShards(int)}), and a common suffix (if supplied using - * {@link #withSuffix(String)}). - * - *

    This default policy can be overridden using {@link #withFilenamePolicy(FilenamePolicy)}, - * in which case {@link #withShardNameTemplate(String)} and {@link #withSuffix(String)} should - * not be set. - */ + /** Like {@link #to(String)}. */ @Experimental(Kind.FILESYSTEM) - public Write to(ResourceId filenamePrefix) { + public TypedWrite to(ResourceId filenamePrefix) { return toResource(StaticValueProvider.of(filenamePrefix)); } - /** - * Like {@link #to(String)}. - */ - public Write to(ValueProvider outputPrefix) { + /** Like {@link #to(String)}. */ + public TypedWrite to(ValueProvider outputPrefix) { return toResource(NestedValueProvider.of(outputPrefix, new SerializableFunction() { @Override @@ -329,42 +391,76 @@ public ResourceId apply(String input) { } /** - * Like {@link #to(ResourceId)}. + * Writes to files named according to the given {@link FileBasedSink.FilenamePolicy}. A + * directory for temporary files must be specified using {@link #withTempDirectory}. */ + public TypedWrite to(FilenamePolicy filenamePolicy) { + return toBuilder().setFilenamePolicy(filenamePolicy).build(); + } + + /** + * Use a {@link DynamicDestinations} object to vend {@link FilenamePolicy} objects. These + * objects can examine the input record when creating a {@link FilenamePolicy}. A directory for + * temporary files must be specified using {@link #withTempDirectory}. + */ + public TypedWrite to(DynamicDestinations dynamicDestinations) { + return toBuilder().setDynamicDestinations(dynamicDestinations).build(); + } + + /** + * Write to dynamic destinations using the default filename policy. The destinationFunction maps + * the input record to a {@link DefaultFilenamePolicy.Params} object that specifies where the + * records should be written (base filename, file suffix, and shard template). The + * emptyDestination parameter specified where empty files should be written for when the written + * {@link PCollection} is empty. + */ + public TypedWrite to( + SerializableFunction destinationFunction, Params emptyDestination) { + return to(DynamicFileDestinations.toDefaultPolicies(destinationFunction, emptyDestination)); + } + + /** Like {@link #to(ResourceId)}. */ @Experimental(Kind.FILESYSTEM) - public Write toResource(ValueProvider filenamePrefix) { + public TypedWrite toResource(ValueProvider filenamePrefix) { return toBuilder().setFilenamePrefix(filenamePrefix).build(); } + /** Set the base directory used to generate temporary files. */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite withTempDirectory(ValueProvider tempDirectory) { + return toBuilder().setTempDirectory(tempDirectory).build(); + } + + /** Set the base directory used to generate temporary files. */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite withTempDirectory(ResourceId tempDirectory) { + return withTempDirectory(StaticValueProvider.of(tempDirectory)); + } + /** * Uses the given {@link ShardNameTemplate} for naming output files. This option may only be - * used when {@link #withFilenamePolicy(FilenamePolicy)} has not been configured. + * used when using one of the default filename-prefix to() overrides - i.e. not when using + * either {@link #to(FilenamePolicy)} or {@link #to(DynamicDestinations)}. * *

    See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are * used. */ - public Write withShardNameTemplate(String shardTemplate) { + public TypedWrite withShardNameTemplate(String shardTemplate) { return toBuilder().setShardTemplate(shardTemplate).build(); } /** - * Configures the filename suffix for written files. This option may only be used when - * {@link #withFilenamePolicy(FilenamePolicy)} has not been configured. + * Configures the filename suffix for written files. This option may only be used when using one + * of the default filename-prefix to() overrides - i.e. not when using either {@link + * #to(FilenamePolicy)} or {@link #to(DynamicDestinations)}. * *

    See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are * used. */ - public Write withSuffix(String filenameSuffix) { + public TypedWrite withSuffix(String filenameSuffix) { return toBuilder().setFilenameSuffix(filenameSuffix).build(); } - /** - * Configures the {@link FileBasedSink.FilenamePolicy} that will be used to name written files. - */ - public Write withFilenamePolicy(FilenamePolicy filenamePolicy) { - return toBuilder().setFilenamePolicy(filenamePolicy).build(); - } - /** * Configures the number of output shards produced overall (when using unwindowed writes) or * per-window (when using windowed writes). @@ -375,14 +471,13 @@ public Write withFilenamePolicy(FilenamePolicy filenamePolicy) { * * @param numShards the number of shards to use, or 0 to let the system decide. */ - public Write withNumShards(int numShards) { + public TypedWrite withNumShards(int numShards) { checkArgument(numShards >= 0); return toBuilder().setNumShards(numShards).build(); } /** - * Forces a single file as output and empty shard name template. This option is only compatible - * with unwindowed writes. + * Forces a single file as output and empty shard name template. * *
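For the default prefix-based to() overloads, the shard template and suffix combine as exercised by the DefaultFilenamePolicy tests later in this patch; a sketch, assuming lines is a PCollection<String> and the output path is illustrative:

    lines.apply(
        TextIO.write()
            .to("/path/to/output")
            .withShardNameTemplate("-SSS-of-NNN")
            .withSuffix(".txt")
            .withNumShards(3));
    // Expected files: /path/to/output-000-of-003.txt ... /path/to/output-002-of-003.txt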

    For unwindowed writes, constraining the number of shards is likely to reduce the * performance of a pipeline. Setting this value is not recommended unless you require a @@ -390,7 +485,7 @@ public Write withNumShards(int numShards) { * *

    This is equivalent to {@code .withNumShards(1).withShardNameTemplate("")} */ - public Write withoutSharding() { + public TypedWrite withoutSharding() { return withNumShards(1).withShardNameTemplate(""); } @@ -399,7 +494,7 @@ public Write withoutSharding() { * *
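The equivalence stated for withoutSharding() can be written out directly (path illustrative, lines as above):

    lines.apply(TextIO.write().to("/path/to/single").withoutSharding());
    // behaves the same as
    lines.apply(TextIO.write().to("/path/to/single").withNumShards(1).withShardNameTemplate(""));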

    A {@code null} value will clear any previously configured header. */ - public Write withHeader(@Nullable String header) { + public TypedWrite withHeader(@Nullable String header) { return toBuilder().setHeader(header).build(); } @@ -408,48 +503,82 @@ public Write withHeader(@Nullable String header) { * *

    A {@code null} value will clear any previously configured footer. */ - public Write withFooter(@Nullable String footer) { + public TypedWrite withFooter(@Nullable String footer) { return toBuilder().setFooter(footer).build(); } /** - * Returns a transform for writing to text files like this one but that has the given - * {@link WritableByteChannelFactory} to be used by the {@link FileBasedSink} during output. - * The default is value is {@link FileBasedSink.CompressionType#UNCOMPRESSED}. + * Returns a transform for writing to text files like this one but that has the given {@link + * WritableByteChannelFactory} to be used by the {@link FileBasedSink} during output. The + * default value is {@link FileBasedSink.CompressionType#UNCOMPRESSED}. * *

    A {@code null} value will reset the value to the default value mentioned above. */ - public Write withWritableByteChannelFactory( + public TypedWrite withWritableByteChannelFactory( WritableByteChannelFactory writableByteChannelFactory) { return toBuilder().setWritableByteChannelFactory(writableByteChannelFactory).build(); } - public Write withWindowedWrites() { + /** + * Preserves windowing of input elements and writes them to files based on the element's window. + * + *
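A sketch combining the header, footer and WritableByteChannelFactory options above; the CSV-style header text and the GZIP factory are illustrative assumptions (GZIP being another FileBasedSink.CompressionType value alongside the UNCOMPRESSED default):

    lines.apply(
        TextIO.write()
            .to("/path/to/report")
            .withHeader("id,name,score")      // written at the start of every shard
            .withFooter("-- end --")          // written at the end of every shard
            .withWritableByteChannelFactory(FileBasedSink.CompressionType.GZIP));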

    If using {@link #to(FileBasedSink.FilenamePolicy)}. Filenames will be generated using + * {@link FilenamePolicy#windowedFilename}. See also {@link WriteFiles#withWindowedWrites()}. + */ + public TypedWrite withWindowedWrites() { return toBuilder().setWindowedWrites(true).build(); } + private DynamicDestinations resolveDynamicDestinations() { + DynamicDestinations dynamicDestinations = getDynamicDestinations(); + if (dynamicDestinations == null) { + FilenamePolicy usedFilenamePolicy = getFilenamePolicy(); + if (usedFilenamePolicy == null) { + usedFilenamePolicy = + DefaultFilenamePolicy.fromStandardParameters( + getFilenamePrefix(), + getShardTemplate(), + getFilenameSuffix(), + getWindowedWrites()); + } + dynamicDestinations = DynamicFileDestinations.constant(usedFilenamePolicy); + } + return dynamicDestinations; + } + @Override - public PDone expand(PCollection input) { - checkState(getFilenamePrefix() != null, - "Need to set the filename prefix of a TextIO.Write transform."); + public PDone expand(PCollection input) { + checkState( + getFilenamePrefix() != null || getTempDirectory() != null, + "Need to set either the filename prefix or the tempDirectory of a TextIO.Write " + + "transform."); checkState( - (getFilenamePolicy() == null) - || (getShardTemplate() == null && getFilenameSuffix() == null), - "Cannot set a filename policy and also a filename template or suffix."); - - FilenamePolicy usedFilenamePolicy = getFilenamePolicy(); - if (usedFilenamePolicy == null) { - usedFilenamePolicy = DefaultFilenamePolicy.constructUsingStandardParameters( - getFilenamePrefix(), getShardTemplate(), getFilenameSuffix(), getWindowedWrites()); + getFilenamePolicy() == null || getDynamicDestinations() == null, + "Cannot specify both a filename policy and dynamic destinations"); + if (getFilenamePolicy() != null || getDynamicDestinations() != null) { + checkState( + getShardTemplate() == null && getFilenameSuffix() == null, + "shardTemplate and filenameSuffix should only be used with the default " + + "filename policy"); } - WriteFiles write = + return expandTyped(input, resolveDynamicDestinations()); + } + + public PDone expandTyped( + PCollection input, DynamicDestinations dynamicDestinations) { + ValueProvider tempDirectory = getTempDirectory(); + if (tempDirectory == null) { + tempDirectory = getFilenamePrefix(); + } + WriteFiles write = WriteFiles.to( - new TextSink( - getFilenamePrefix(), - usedFilenamePolicy, + new TextSink<>( + tempDirectory, + dynamicDestinations, getHeader(), getFooter(), - getWritableByteChannelFactory())); + getWritableByteChannelFactory()), + getFormatFunction()); if (getNumShards() > 0) { write = write.withNumShards(getNumShards()); } @@ -463,27 +592,26 @@ public PDone expand(PCollection input) { public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - String prefixString = ""; - if (getFilenamePrefix() != null) { - prefixString = getFilenamePrefix().isAccessible() - ? getFilenamePrefix().get().toString() : getFilenamePrefix().toString(); + resolveDynamicDestinations().populateDisplayData(builder); + String tempDirectory = null; + if (getTempDirectory() != null) { + tempDirectory = + getTempDirectory().isAccessible() + ? 
getTempDirectory().get().toString() + : getTempDirectory().toString(); } builder - .addIfNotNull(DisplayData.item("filePrefix", prefixString) - .withLabel("Output File Prefix")) - .addIfNotNull(DisplayData.item("fileSuffix", getFilenameSuffix()) - .withLabel("Output File Suffix")) - .addIfNotNull(DisplayData.item("shardNameTemplate", getShardTemplate()) - .withLabel("Output Shard Name Template")) - .addIfNotDefault(DisplayData.item("numShards", getNumShards()) - .withLabel("Maximum Output Shards"), 0) - .addIfNotNull(DisplayData.item("fileHeader", getHeader()) - .withLabel("File Header")) - .addIfNotNull(DisplayData.item("fileFooter", getFooter()) - .withLabel("File Footer")) - .add(DisplayData - .item("writableByteChannelFactory", getWritableByteChannelFactory().toString()) - .withLabel("Compression/Transformation Type")); + .addIfNotDefault( + DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"), 0) + .addIfNotNull( + DisplayData.item("tempDirectory", tempDirectory) + .withLabel("Directory for temporary files")) + .addIfNotNull(DisplayData.item("fileHeader", getHeader()).withLabel("File Header")) + .addIfNotNull(DisplayData.item("fileFooter", getFooter()).withLabel("File Footer")) + .add( + DisplayData.item( + "writableByteChannelFactory", getWritableByteChannelFactory().toString()) + .withLabel("Compression/Transformation Type")); } @Override @@ -492,6 +620,128 @@ protected Coder getDefaultOutputCoder() { } } + /** + * This class is used as the default return value of {@link TextIO#write()}. + * + *
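The TypedWrite path above is reached through TextIO.writeCustomType, which supplies the format function consumed by expandTyped; a sketch in which the Event type, its toCsv() method and the sample data are illustrative assumptions:

    PCollection<Event> events = p.apply(Create.of(new Event("a"), new Event("b")));
    events.apply(
        TextIO.<Event>writeCustomType(
                new SerializableFunction<Event, String>() {
                  @Override
                  public String apply(Event event) {
                    return event.toCsv();   // user type -> text line written to the file
                  }
                })
            .to("/path/to/events")
            .withSuffix(".csv"));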

    All methods in this class delegate to the appropriate method of {@link TextIO.TypedWrite}. + * This class exists for backwards compatibility, and will be removed in Beam 3.0. + */ + public static class Write extends PTransform, PDone> { + @VisibleForTesting TypedWrite inner; + + Write() { + this(TextIO.writeCustomType(SerializableFunctions.identity())); + } + + Write(TypedWrite inner) { + this.inner = inner; + } + + /** See {@link TypedWrite#to(String)}. */ + public Write to(String filenamePrefix) { + return new Write(inner.to(filenamePrefix)); + } + + /** See {@link TypedWrite#to(ResourceId)}. */ + @Experimental(Kind.FILESYSTEM) + public Write to(ResourceId filenamePrefix) { + return new Write(inner.to(filenamePrefix)); + } + + /** See {@link TypedWrite#to(ValueProvider)}. */ + public Write to(ValueProvider outputPrefix) { + return new Write(inner.to(outputPrefix)); + } + + /** See {@link TypedWrite#toResource(ValueProvider)}. */ + @Experimental(Kind.FILESYSTEM) + public Write toResource(ValueProvider filenamePrefix) { + return new Write(inner.toResource(filenamePrefix)); + } + + /** See {@link TypedWrite#to(FilenamePolicy)}. */ + @Experimental(Kind.FILESYSTEM) + public Write to(FilenamePolicy filenamePolicy) { + return new Write(inner.to(filenamePolicy)); + } + + /** See {@link TypedWrite#to(DynamicDestinations)}. */ + @Experimental(Kind.FILESYSTEM) + public Write to(DynamicDestinations dynamicDestinations) { + return new Write(inner.to(dynamicDestinations)); + } + + /** See {@link TypedWrite#to(SerializableFunction, Params)}. */ + @Experimental(Kind.FILESYSTEM) + public Write to( + SerializableFunction destinationFunction, Params emptyDestination) { + return new Write(inner.to(destinationFunction, emptyDestination)); + } + + /** See {@link TypedWrite#withTempDirectory(ValueProvider)}. */ + @Experimental(Kind.FILESYSTEM) + public Write withTempDirectory(ValueProvider tempDirectory) { + return new Write(inner.withTempDirectory(tempDirectory)); + } + + /** See {@link TypedWrite#withTempDirectory(ResourceId)}. */ + @Experimental(Kind.FILESYSTEM) + public Write withTempDirectory(ResourceId tempDirectory) { + return new Write(inner.withTempDirectory(tempDirectory)); + } + + /** See {@link TypedWrite#withShardNameTemplate(String)}. */ + public Write withShardNameTemplate(String shardTemplate) { + return new Write(inner.withShardNameTemplate(shardTemplate)); + } + + /** See {@link TypedWrite#withSuffix(String)}. */ + public Write withSuffix(String filenameSuffix) { + return new Write(inner.withSuffix(filenameSuffix)); + } + + /** See {@link TypedWrite#withNumShards(int)}. */ + public Write withNumShards(int numShards) { + return new Write(inner.withNumShards(numShards)); + } + + /** See {@link TypedWrite#withoutSharding()}. */ + public Write withoutSharding() { + return new Write(inner.withoutSharding()); + } + + /** See {@link TypedWrite#withHeader(String)}. */ + public Write withHeader(@Nullable String header) { + return new Write(inner.withHeader(header)); + } + + /** See {@link TypedWrite#withFooter(String)}. */ + public Write withFooter(@Nullable String footer) { + return new Write(inner.withFooter(footer)); + } + + /** See {@link TypedWrite#withWritableByteChannelFactory(WritableByteChannelFactory)}. */ + public Write withWritableByteChannelFactory( + WritableByteChannelFactory writableByteChannelFactory) { + return new Write(inner.withWritableByteChannelFactory(writableByteChannelFactory)); + } + + /** See {@link TypedWrite#withWindowedWrites}. 
*/ + public Write withWindowedWrites() { + return new Write(inner.withWindowedWrites()); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + inner.populateDisplayData(builder); + } + + @Override + public PDone expand(PCollection input) { + return inner.expand(input); + } + } + /** * Possible text file compression types. */ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java index 511d6976a4e90..b57b28c5c0310 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java @@ -34,27 +34,29 @@ * '\n'} represented in {@code UTF-8} format as the record separator. Each record (including the * last) is terminated. */ -class TextSink extends FileBasedSink { +class TextSink extends FileBasedSink { @Nullable private final String header; @Nullable private final String footer; TextSink( ValueProvider baseOutputFilename, - FilenamePolicy filenamePolicy, + DynamicDestinations dynamicDestinations, @Nullable String header, @Nullable String footer, WritableByteChannelFactory writableByteChannelFactory) { - super(baseOutputFilename, filenamePolicy, writableByteChannelFactory); + super(baseOutputFilename, dynamicDestinations, writableByteChannelFactory); this.header = header; this.footer = footer; } + @Override - public WriteOperation createWriteOperation() { - return new TextWriteOperation(this, header, footer); + public WriteOperation createWriteOperation() { + return new TextWriteOperation<>(this, header, footer); } /** A {@link WriteOperation WriteOperation} for text files. */ - private static class TextWriteOperation extends WriteOperation { + private static class TextWriteOperation + extends WriteOperation { @Nullable private final String header; @Nullable private final String footer; @@ -65,20 +67,20 @@ private TextWriteOperation(TextSink sink, @Nullable String header, @Nullable Str } @Override - public Writer createWriter() throws Exception { - return new TextWriter(this, header, footer); + public Writer createWriter() throws Exception { + return new TextWriter<>(this, header, footer); } } /** A {@link Writer Writer} for text files. 
*/ - private static class TextWriter extends Writer { + private static class TextWriter extends Writer { private static final String NEWLINE = "\n"; @Nullable private final String header; @Nullable private final String footer; private OutputStreamWriter out; public TextWriter( - WriteOperation writeOperation, + WriteOperation writeOperation, @Nullable String header, @Nullable String footer) { super(writeOperation, MimeTypes.TEXT); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java index a220eabfe42d9..7013044f600ad 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java @@ -20,9 +20,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.hash.Hashing; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.UUID; @@ -30,8 +33,11 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ShardedKeyCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.FileBasedSink.FileResult; @@ -47,6 +53,7 @@ import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; @@ -55,6 +62,7 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; @@ -62,6 +70,7 @@ import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.slf4j.Logger; @@ -72,13 +81,12 @@ * global initialization of a sink, followed by a parallel write, and ends with a sequential * finalization of the write. The output of a write is {@link PDone}. * - *

    By default, every bundle in the input {@link PCollection} will be processed by a - * {@link WriteOperation}, so the number of output - * will vary based on runner behavior, though at least 1 output will always be produced. The - * exact parallelism of the write stage can be controlled using {@link WriteFiles#withNumShards}, - * typically used to control how many files are produced or to globally limit the number of - * workers connecting to an external service. However, this option can often hurt performance: it - * adds an additional {@link GroupByKey} to the pipeline. + *

    By default, every bundle in the input {@link PCollection} will be processed by a {@link + * WriteOperation}, so the number of output will vary based on runner behavior, though at least 1 + * output will always be produced. The exact parallelism of the write stage can be controlled using + * {@link WriteFiles#withNumShards}, typically used to control how many files are produced or to + * globally limit the number of workers connecting to an external service. However, this option can + * often hurt performance: it adds an additional {@link GroupByKey} to the pipeline. * *

    Example usage with runner-determined sharding: * @@ -89,7 +97,8 @@ *

    {@code p.apply(WriteFiles.to(new MySink(...)).withNumShards(3));}
    */ @Experimental(Experimental.Kind.SOURCE_SINK) -public class WriteFiles extends PTransform, PDone> { +public class WriteFiles + extends PTransform, PDone> { private static final Logger LOG = LoggerFactory.getLogger(WriteFiles.class); // The maximum number of file writers to keep open in a single bundle at a time, since file @@ -105,12 +114,12 @@ public class WriteFiles extends PTransform, PDone> { private static final int SPILLED_RECORD_SHARDING_FACTOR = 10; static final int UNKNOWN_SHARDNUM = -1; - private FileBasedSink sink; - private WriteOperation writeOperation; + private FileBasedSink sink; + private SerializableFunction formatFunction; + private WriteOperation writeOperation; // This allows the number of shards to be dynamically computed based on the input // PCollection. - @Nullable - private final PTransform, PCollectionView> computeNumShards; + @Nullable private final PTransform, PCollectionView> computeNumShards; // We don't use a side input for static sharding, as we want this value to be updatable // when a pipeline is updated. @Nullable @@ -122,19 +131,28 @@ public class WriteFiles extends PTransform, PDone> { * Creates a {@link WriteFiles} transform that writes to the given {@link FileBasedSink}, letting * the runner control how many different shards are produced. */ - public static WriteFiles to(FileBasedSink sink) { + public static WriteFiles to( + FileBasedSink sink, + SerializableFunction formatFunction) { checkNotNull(sink, "sink"); - return new WriteFiles<>(sink, null /* runner-determined sharding */, null, - false, DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE); + return new WriteFiles<>( + sink, + formatFunction, + null /* runner-determined sharding */, + null, + false, + DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE); } private WriteFiles( - FileBasedSink sink, - @Nullable PTransform, PCollectionView> computeNumShards, + FileBasedSink sink, + SerializableFunction formatFunction, + @Nullable PTransform, PCollectionView> computeNumShards, @Nullable ValueProvider numShardsProvider, boolean windowedWrites, int maxNumWritersPerBundle) { this.sink = sink; + this.formatFunction = checkNotNull(formatFunction); this.computeNumShards = computeNumShards; this.numShardsProvider = numShardsProvider; this.windowedWrites = windowedWrites; @@ -142,7 +160,7 @@ private WriteFiles( } @Override - public PDone expand(PCollection input) { + public PDone expand(PCollection input) { if (input.isBounded() == IsBounded.UNBOUNDED) { checkArgument(windowedWrites, "Must use windowed writes when applying %s to an unbounded PCollection", @@ -181,13 +199,16 @@ public void populateDisplayData(DisplayData.Builder builder) { } } - /** - * Returns the {@link FileBasedSink} associated with this PTransform. - */ - public FileBasedSink getSink() { + /** Returns the {@link FileBasedSink} associated with this PTransform. */ + public FileBasedSink getSink() { return sink; } + /** Returns the the format function that maps the user type to the record written to files. */ + public SerializableFunction getFormatFunction() { + return formatFunction; + } + /** * Returns whether or not to perform windowed writes. */ @@ -202,7 +223,7 @@ public boolean isWindowedWrites() { * #withRunnerDeterminedSharding()}. */ @Nullable - public PTransform, PCollectionView> getSharding() { + public PTransform, PCollectionView> getSharding() { return computeNumShards; } @@ -220,7 +241,7 @@ public ValueProvider getNumShards() { *

    A value less than or equal to 0 will be equivalent to the default behavior of * runner-determined sharding. */ - public WriteFiles withNumShards(int numShards) { + public WriteFiles withNumShards(int numShards) { if (numShards > 0) { return withNumShards(StaticValueProvider.of(numShards)); } @@ -234,16 +255,26 @@ public WriteFiles withNumShards(int numShards) { *

    This option should be used sparingly as it can hurt performance. See {@link WriteFiles} for * more information. */ - public WriteFiles withNumShards(ValueProvider numShardsProvider) { - return new WriteFiles<>(sink, null, numShardsProvider, windowedWrites, + public WriteFiles withNumShards( + ValueProvider numShardsProvider) { + return new WriteFiles<>( + sink, + formatFunction, + computeNumShards, + numShardsProvider, + windowedWrites, maxNumWritersPerBundle); } - /** - * Set the maximum number of writers created in a bundle before spilling to shuffle. - */ - public WriteFiles withMaxNumWritersPerBundle(int maxNumWritersPerBundle) { - return new WriteFiles<>(sink, null, numShardsProvider, windowedWrites, + /** Set the maximum number of writers created in a bundle before spilling to shuffle. */ + public WriteFiles withMaxNumWritersPerBundle( + int maxNumWritersPerBundle) { + return new WriteFiles<>( + sink, + formatFunction, + computeNumShards, + numShardsProvider, + windowedWrites, maxNumWritersPerBundle); } @@ -254,97 +285,167 @@ public WriteFiles withMaxNumWritersPerBundle(int maxNumWritersPerBundle) { *
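With this change the WriteFiles factory takes the sink together with a format function from the user type to the sink's output type; the identity function from the SerializableFunctions utility added later in this patch covers the common case where the two types coincide. A sketch reusing the MySink placeholder from the class javadoc above:

    p.apply(Create.of("a", "b"))
     .apply(
         WriteFiles.to(new MySink(...), SerializableFunctions.<String>identity())
             .withNumShards(3));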

    This option should be used sparingly as it can hurt performance. See {@link WriteFiles} for * more information. */ - public WriteFiles withSharding(PTransform, PCollectionView> sharding) { + public WriteFiles withSharding( + PTransform, PCollectionView> sharding) { checkNotNull( sharding, "Cannot provide null sharding. Use withRunnerDeterminedSharding() instead"); - return new WriteFiles<>(sink, sharding, null, windowedWrites, maxNumWritersPerBundle); + return new WriteFiles<>( + sink, formatFunction, sharding, null, windowedWrites, maxNumWritersPerBundle); } /** * Returns a new {@link WriteFiles} that will write to the current {@link FileBasedSink} with * runner-determined sharding. */ - public WriteFiles withRunnerDeterminedSharding() { - return new WriteFiles<>(sink, null, null, windowedWrites, maxNumWritersPerBundle); + public WriteFiles withRunnerDeterminedSharding() { + return new WriteFiles<>( + sink, formatFunction, null, null, windowedWrites, maxNumWritersPerBundle); } /** * Returns a new {@link WriteFiles} that writes preserves windowing on it's input. * - *

    If this option is not specified, windowing and triggering are replaced by - * {@link GlobalWindows} and {@link DefaultTrigger}. + *

    If this option is not specified, windowing and triggering are replaced by {@link + * GlobalWindows} and {@link DefaultTrigger}. * - *

    If there is no data for a window, no output shards will be generated for that window. - * If a window triggers multiple times, then more than a single output shard might be - * generated multiple times; it's up to the sink implementation to keep these output shards - * unique. + *

    If there is no data for a window, no output shards will be generated for that window. If a + * window triggers multiple times, then more than a single output shard might be generated + * multiple times; it's up to the sink implementation to keep these output shards unique. * - *
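A sketch of enabling windowed writes, mirroring the AvroIO test change later in this patch; the one-minute fixed windows, two shards and MySink placeholder are illustrative:

    input
        .apply(Window.<String>into(FixedWindows.of(Duration.standardMinutes(1))))
        .apply(
            WriteFiles.to(new MySink(...), SerializableFunctions.<String>identity())
                .withWindowedWrites()
                .withNumShards(2));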

    This option can only be used if {@link #withNumShards(int)} is also set to a - * positive value. + *

    This option can only be used if {@link #withNumShards(int)} is also set to a positive value. */ - public WriteFiles withWindowedWrites() { - return new WriteFiles<>(sink, computeNumShards, numShardsProvider, true, - maxNumWritersPerBundle); + public WriteFiles withWindowedWrites() { + return new WriteFiles<>( + sink, formatFunction, computeNumShards, numShardsProvider, true, maxNumWritersPerBundle); + } + + private static class WriterKey { + private final BoundedWindow window; + private final PaneInfo paneInfo; + private final DestinationT destination; + + WriterKey(BoundedWindow window, PaneInfo paneInfo, DestinationT destination) { + this.window = window; + this.paneInfo = paneInfo; + this.destination = destination; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof WriterKey)) { + return false; + } + WriterKey other = (WriterKey) o; + return Objects.equal(window, other.window) + && Objects.equal(paneInfo, other.paneInfo) + && Objects.equal(destination, other.destination); + } + + @Override + public int hashCode() { + return Objects.hashCode(window, paneInfo, destination); + } + } + + // Hash the destination in a manner that we can then use as a key in a GBK. Since Java's + // hashCode isn't guaranteed to be stable across machines, we instead serialize the destination + // and use murmur3_32 to hash it. We enforce that destinationCoder must be deterministic, so + // this can be used as a key. + private static int hashDestination( + DestinationT destination, Coder destinationCoder) throws IOException { + return Hashing.murmur3_32() + .hashBytes(CoderUtils.encodeToByteArray(destinationCoder, destination)) + .asInt(); } /** - * Writes all the elements in a bundle using a {@link Writer} produced by the - * {@link WriteOperation} associated with the {@link FileBasedSink} with windowed writes enabled. + * Writes all the elements in a bundle using a {@link Writer} produced by the {@link + * WriteOperation} associated with the {@link FileBasedSink}. */ - private class WriteWindowedBundles extends DoFn { - private final TupleTag> unwrittedRecordsTag; - private Map, Writer> windowedWriters; - int spilledShardNum = UNKNOWN_SHARDNUM; - - WriteWindowedBundles(TupleTag> unwrittedRecordsTag) { - this.unwrittedRecordsTag = unwrittedRecordsTag; + private class WriteBundles extends DoFn> { + private final TupleTag, UserT>> unwrittenRecordsTag; + private final Coder destinationCoder; + private final boolean windowedWrites; + + private Map, Writer> writers; + private int spilledShardNum = UNKNOWN_SHARDNUM; + + WriteBundles( + boolean windowedWrites, + TupleTag, UserT>> unwrittenRecordsTag, + Coder destinationCoder) { + this.windowedWrites = windowedWrites; + this.unwrittenRecordsTag = unwrittenRecordsTag; + this.destinationCoder = destinationCoder; } @StartBundle public void startBundle(StartBundleContext c) { // Reset state in case of reuse. We need to make sure that each bundle gets unique writers. - windowedWriters = Maps.newHashMap(); + writers = Maps.newHashMap(); } @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) throws Exception { PaneInfo paneInfo = c.pane(); - Writer writer; // If we are doing windowed writes, we need to ensure that we have separate files for - // data in different windows/panes. - KV key = KV.of(window, paneInfo); - writer = windowedWriters.get(key); + // data in different windows/panes. Similar for dynamic writes, make sure that different + // destinations go to different writers. 
+ // In the case of unwindowed writes, the window and the pane will always be the same, and + // the map will only have a single element. + DestinationT destination = sink.getDynamicDestinations().getDestination(c.element()); + WriterKey key = new WriterKey<>(window, c.pane(), destination); + Writer writer = writers.get(key); if (writer == null) { - if (windowedWriters.size() <= maxNumWritersPerBundle) { + if (writers.size() <= maxNumWritersPerBundle) { String uuid = UUID.randomUUID().toString(); LOG.info( - "Opening writer {} for write operation {}, window {} pane {}", + "Opening writer {} for write operation {}, window {} pane {} destination {}", uuid, writeOperation, window, - paneInfo); + paneInfo, + destination); writer = writeOperation.createWriter(); - writer.openWindowed(uuid, window, paneInfo, UNKNOWN_SHARDNUM); - windowedWriters.put(key, writer); + if (windowedWrites) { + writer.openWindowed(uuid, window, paneInfo, UNKNOWN_SHARDNUM, destination); + } else { + writer.openUnwindowed(uuid, UNKNOWN_SHARDNUM, destination); + } + writers.put(key, writer); LOG.debug("Done opening writer"); } else { if (spilledShardNum == UNKNOWN_SHARDNUM) { + // Cache the random value so we only call ThreadLocalRandom once per DoFn instance. spilledShardNum = ThreadLocalRandom.current().nextInt(SPILLED_RECORD_SHARDING_FACTOR); } else { spilledShardNum = (spilledShardNum + 1) % SPILLED_RECORD_SHARDING_FACTOR; } - c.output(unwrittedRecordsTag, KV.of(spilledShardNum, c.element())); + c.output( + unwrittenRecordsTag, + KV.of( + ShardedKey.of(hashDestination(destination, destinationCoder), spilledShardNum), + c.element())); return; } } - writeOrClose(writer, c.element()); + writeOrClose(writer, formatFunction.apply(c.element())); } @FinishBundle public void finishBundle(FinishBundleContext c) throws Exception { - for (Map.Entry, Writer> entry : windowedWriters.entrySet()) { - FileResult result = entry.getValue().close(); - BoundedWindow window = entry.getKey().getKey(); + for (Map.Entry, Writer> entry : + writers.entrySet()) { + Writer writer = entry.getValue(); + FileResult result; + try { + result = writer.close(); + } catch (Exception e) { + // If anything goes wrong, make sure to delete the temporary file. + writer.cleanup(); + throw e; + } + BoundedWindow window = entry.getKey().window; c.output(result, window.maxTimestamp(), window); } } @@ -355,90 +456,62 @@ public void populateDisplayData(DisplayData.Builder builder) { } } - /** - * Writes all the elements in a bundle using a {@link Writer} produced by the - * {@link WriteOperation} associated with the {@link FileBasedSink} with windowed writes disabled. - */ - private class WriteUnwindowedBundles extends DoFn { - // Writer that will write the records in this bundle. Lazily - // initialized in processElement. - private Writer writer = null; - private BoundedWindow window = null; - - @StartBundle - public void startBundle(StartBundleContext c) { - // Reset state in case of reuse. We need to make sure that each bundle gets unique writers. - writer = null; - } - - @ProcessElement - public void processElement(ProcessContext c, BoundedWindow window) throws Exception { - // Cache a single writer for the bundle. 
- if (writer == null) { - LOG.info("Opening writer for write operation {}", writeOperation); - writer = writeOperation.createWriter(); - writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM); - LOG.debug("Done opening writer"); - } - this.window = window; - writeOrClose(this.writer, c.element()); - } + enum ShardAssignment { ASSIGN_IN_FINALIZE, ASSIGN_WHEN_WRITING } - @FinishBundle - public void finishBundle(FinishBundleContext c) throws Exception { - if (writer == null) { - return; - } - FileResult result = writer.close(); - c.output(result, window.maxTimestamp(), window); - } - - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.delegate(WriteFiles.this); - } - } - - enum ShardAssignment { ASSIGN_IN_FINALIZE, ASSIGN_WHEN_WRITING }; - - /** - * Like {@link WriteWindowedBundles} and {@link WriteUnwindowedBundles}, but where the elements - * for each shard have been collected into a single iterable. + /* + * Like {@link WriteBundles}, but where the elements for each shard have been collected into a + * single iterable. */ - private class WriteShardedBundles extends DoFn>, FileResult> { + private class WriteShardedBundles + extends DoFn, Iterable>, FileResult> { ShardAssignment shardNumberAssignment; WriteShardedBundles(ShardAssignment shardNumberAssignment) { this.shardNumberAssignment = shardNumberAssignment; } + @ProcessElement public void processElement(ProcessContext c, BoundedWindow window) throws Exception { - // In a sharded write, single input element represents one shard. We can open and close - // the writer in each call to processElement. - LOG.info("Opening writer for write operation {}", writeOperation); - Writer writer = writeOperation.createWriter(); - if (windowedWrites) { - int shardNumber = shardNumberAssignment == ShardAssignment.ASSIGN_WHEN_WRITING - ? c.element().getKey() : UNKNOWN_SHARDNUM; - writer.openWindowed(UUID.randomUUID().toString(), window, c.pane(), shardNumber); - } else { - writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM); - } - LOG.debug("Done opening writer"); - - try { - for (T t : c.element().getValue()) { - writeOrClose(writer, t); + // Since we key by a 32-bit hash of the destination, there might be multiple destinations + // in this iterable. The number of destinations is generally very small (1000s or less), so + // there will rarely be hash collisions. + Map> writers = Maps.newHashMap(); + for (UserT input : c.element().getValue()) { + DestinationT destination = sink.getDynamicDestinations().getDestination(input); + Writer writer = writers.get(destination); + if (writer == null) { + LOG.debug("Opening writer for write operation {}", writeOperation); + writer = writeOperation.createWriter(); + if (windowedWrites) { + int shardNumber = + shardNumberAssignment == ShardAssignment.ASSIGN_WHEN_WRITING + ? c.element().getKey().getShardNumber() + : UNKNOWN_SHARDNUM; + writer.openWindowed( + UUID.randomUUID().toString(), window, c.pane(), shardNumber, destination); + } else { + writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM, destination); + } + LOG.debug("Done opening writer"); + writers.put(destination, writer); + } + writeOrClose(writer, formatFunction.apply(input)); } - // Close the writer; if this throws let the error propagate. - FileResult result = writer.close(); - c.output(result); - } catch (Exception e) { - // If anything goes wrong, make sure to delete the temporary file. - writer.cleanup(); - throw e; + // Close all writers. 
+ for (Map.Entry> entry : writers.entrySet()) { + Writer writer = entry.getValue(); + FileResult result; + try { + // Close the writer; if this throws let the error propagate. + result = writer.close(); + c.output(result); + } catch (Exception e) { + // If anything goes wrong, make sure to delete the temporary file. + writer.cleanup(); + throw e; + } + } } - } @Override public void populateDisplayData(DisplayData.Builder builder) { @@ -446,12 +519,15 @@ public void populateDisplayData(DisplayData.Builder builder) { } } - private static void writeOrClose(Writer writer, T t) throws Exception { + private static void writeOrClose( + Writer writer, OutputT t) throws Exception { try { writer.write(t); } catch (Exception e) { try { writer.close(); + // If anything goes wrong, make sure to delete the temporary file. + writer.cleanup(); } catch (Exception closeException) { if (closeException instanceof InterruptedException) { // Do not silently ignore interrupted state. @@ -464,20 +540,25 @@ private static void writeOrClose(Writer writer, T t) throws Exception { } } - private static class ApplyShardingKey extends DoFn> { + private class ApplyShardingKey extends DoFn, UserT>> { private final PCollectionView numShardsView; private final ValueProvider numShardsProvider; + private final Coder destinationCoder; + private int shardNumber; - ApplyShardingKey(PCollectionView numShardsView, - ValueProvider numShardsProvider) { + ApplyShardingKey( + PCollectionView numShardsView, + ValueProvider numShardsProvider, + Coder destinationCoder) { + this.destinationCoder = destinationCoder; this.numShardsView = numShardsView; this.numShardsProvider = numShardsProvider; shardNumber = UNKNOWN_SHARDNUM; } @ProcessElement - public void processElement(ProcessContext context) { + public void processElement(ProcessContext context) throws IOException { final int shardCount; if (numShardsView != null) { shardCount = context.sideInput(numShardsView); @@ -497,86 +578,110 @@ public void processElement(ProcessContext context) { } else { shardNumber = (shardNumber + 1) % shardCount; } - context.output(KV.of(shardNumber, context.element())); + // We avoid using destination itself as a sharding key, because destination is often large. + // e.g. when using {@link DefaultFilenamePolicy}, the destination contains the entire path + // to the file. Often most of the path is constant across all destinations, just the path + // suffix is appended by the destination function. Instead we key by a 32-bit hash (carefully + // chosen to be guaranteed stable), and call getDestination again in the next ParDo to resolve + // the destinations. This does mean that multiple destinations might end up on the same shard, + // however the number of collisions should be small, so there's no need to worry about memory + // issues. + DestinationT destination = sink.getDynamicDestinations().getDestination(context.element()); + context.output( + KV.of( + ShardedKey.of(hashDestination(destination, destinationCoder), shardNumber), + context.element())); } } /** * A write is performed as sequence of three {@link ParDo}'s. * - *
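The sharding key built above hashes the destination's encoded bytes rather than its Java hashCode, so equal destinations map to the same key on every worker; an illustrative sketch, where StringUtf8Coder and the sample destination string are assumptions (CoderUtils.encodeToByteArray throws a checked exception, declared on the enclosing method in the real code):

    Coder<String> destinationCoder = StringUtf8Coder.of();
    int bucket =
        Hashing.murmur3_32()
            .hashBytes(CoderUtils.encodeToByteArray(destinationCoder, "logs/2017/06/07"))
            .asInt();
    // The same destination always yields the same bucket, unlike Object.hashCode(),
    // which is not guaranteed to be stable across JVMs.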

    This singleton collection containing the WriteOperation is then used as a side - * input to a ParDo over the PCollection of elements to write. In this bundle-writing phase, - * {@link WriteOperation#createWriter} is called to obtain a {@link Writer}. - * {@link Writer#open} and {@link Writer#close} are called in - * {@link DoFn.StartBundle} and {@link DoFn.FinishBundle}, respectively, and - * {@link Writer#write} method is called for every element in the bundle. The output - * of this ParDo is a PCollection of writer result objects (see {@link FileBasedSink} - * for a description of writer results)-one for each bundle. + *

    This singleton collection containing the WriteOperation is then used as a side input to a + * ParDo over the PCollection of elements to write. In this bundle-writing phase, {@link + * WriteOperation#createWriter} is called to obtain a {@link Writer}. {@link Writer#open} and + * {@link Writer#close} are called in {@link DoFn.StartBundle} and {@link DoFn.FinishBundle}, + * respectively, and {@link Writer#write} method is called for every element in the bundle. The + * output of this ParDo is a PCollection of writer result objects (see {@link + * FileBasedSink} for a description of writer results)-one for each bundle. * *

    The final do-once ParDo uses a singleton collection asinput and the collection of writer - * results as a side-input. In this ParDo, {@link WriteOperation#finalize} is called - * to finalize the write. + * results as a side-input. In this ParDo, {@link WriteOperation#finalize} is called to finalize + * the write. * - *

    If the write of any element in the PCollection fails, {@link Writer#close} will be - * called before the exception that caused the write to fail is propagated and the write result - * will be discarded. + *

    If the write of any element in the PCollection fails, {@link Writer#close} will be called + * before the exception that caused the write to fail is propagated and the write result will be + * discarded. * *

    Since the {@link WriteOperation} is serialized after the initialization ParDo and * deserialized in the bundle-writing and finalization phases, any state change to the - * WriteOperation object that occurs during initialization is visible in the latter - * phases. However, the WriteOperation is not serialized after the bundle-writing - * phase. This is why implementations should guarantee that - * {@link WriteOperation#createWriter} does not mutate WriteOperation). + * WriteOperation object that occurs during initialization is visible in the latter phases. + * However, the WriteOperation is not serialized after the bundle-writing phase. This is why + * implementations should guarantee that {@link WriteOperation#createWriter} does not mutate + * WriteOperation). */ - private PDone createWrite(PCollection input) { + private PDone createWrite(PCollection input) { Pipeline p = input.getPipeline(); if (!windowedWrites) { // Re-window the data into the global window and remove any existing triggers. input = input.apply( - Window.into(new GlobalWindows()) + Window.into(new GlobalWindows()) .triggering(DefaultTrigger.of()) .discardingFiredPanes()); } - // Perform the per-bundle writes as a ParDo on the input PCollection (with the // WriteOperation as a side input) and collect the results of the writes in a // PCollection. There is a dependency between this ParDo and the first (the // WriteOperation PCollection as a side input), so this will happen after the // initial ParDo. - PCollection results; + PCollection> results; final PCollectionView numShardsView; @SuppressWarnings("unchecked") Coder shardedWindowCoder = (Coder) input.getWindowingStrategy().getWindowFn().windowCoder(); + final Coder destinationCoder; + try { + destinationCoder = + sink.getDynamicDestinations() + .getDestinationCoderWithDefault(input.getPipeline().getCoderRegistry()); + destinationCoder.verifyDeterministic(); + } catch (CannotProvideCoderException | NonDeterministicException e) { + throw new RuntimeException(e); + } + if (computeNumShards == null && numShardsProvider == null) { numShardsView = null; - if (windowedWrites) { - TupleTag writtenRecordsTag = new TupleTag<>("writtenRecordsTag"); - TupleTag> unwrittedRecordsTag = new TupleTag<>("unwrittenRecordsTag"); - PCollectionTuple writeTuple = input.apply("WriteWindowedBundles", ParDo.of( - new WriteWindowedBundles(unwrittedRecordsTag)) - .withOutputTags(writtenRecordsTag, TupleTagList.of(unwrittedRecordsTag))); - PCollection writtenBundleFiles = writeTuple.get(writtenRecordsTag) - .setCoder(FileResultCoder.of(shardedWindowCoder)); - // Any "spilled" elements are written using WriteShardedBundles. Assign shard numbers in - // finalize to stay consistent with what WriteWindowedBundles does. - PCollection writtenGroupedFiles = - writeTuple - .get(unwrittedRecordsTag) - .setCoder(KvCoder.of(VarIntCoder.of(), input.getCoder())) - .apply("GroupUnwritten", GroupByKey.create()) - .apply("WriteUnwritten", ParDo.of( - new WriteShardedBundles(ShardAssignment.ASSIGN_IN_FINALIZE))) - .setCoder(FileResultCoder.of(shardedWindowCoder)); - results = PCollectionList.of(writtenBundleFiles).and(writtenGroupedFiles) - .apply(Flatten.pCollections()); - } else { - results = - input.apply("WriteUnwindowedBundles", ParDo.of(new WriteUnwindowedBundles())); - } + TupleTag> writtenRecordsTag = new TupleTag<>("writtenRecordsTag"); + TupleTag, UserT>> unwrittedRecordsTag = + new TupleTag<>("unwrittenRecordsTag"); + String writeName = windowedWrites ? 
"WriteWindowedBundles" : "WriteBundles"; + PCollectionTuple writeTuple = + input.apply( + writeName, + ParDo.of(new WriteBundles(windowedWrites, unwrittedRecordsTag, destinationCoder)) + .withOutputTags(writtenRecordsTag, TupleTagList.of(unwrittedRecordsTag))); + PCollection> writtenBundleFiles = + writeTuple + .get(writtenRecordsTag) + .setCoder(FileResultCoder.of(shardedWindowCoder, destinationCoder)); + // Any "spilled" elements are written using WriteShardedBundles. Assign shard numbers in + // finalize to stay consistent with what WriteWindowedBundles does. + PCollection> writtenGroupedFiles = + writeTuple + .get(unwrittedRecordsTag) + .setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder())) + .apply("GroupUnwritten", GroupByKey., UserT>create()) + .apply( + "WriteUnwritten", + ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_IN_FINALIZE))) + .setCoder(FileResultCoder.of(shardedWindowCoder, destinationCoder)); + results = + PCollectionList.of(writtenBundleFiles) + .and(writtenGroupedFiles) + .apply(Flatten.>pCollections()); } else { List> sideInputs = Lists.newArrayList(); if (computeNumShards != null) { @@ -585,23 +690,31 @@ private PDone createWrite(PCollection input) { } else { numShardsView = null; } - - PCollection>> sharded = + PCollection, Iterable>> sharded = input - .apply("ApplyShardLabel", ParDo.of( - new ApplyShardingKey(numShardsView, - (numShardsView != null) ? null : numShardsProvider)) - .withSideInputs(sideInputs)) - .apply("GroupIntoShards", GroupByKey.create()); + .apply( + "ApplyShardLabel", + ParDo.of( + new ApplyShardingKey( + numShardsView, + (numShardsView != null) ? null : numShardsProvider, + destinationCoder)) + .withSideInputs(sideInputs)) + .setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder())) + .apply("GroupIntoShards", GroupByKey., UserT>create()); + shardedWindowCoder = + (Coder) sharded.getWindowingStrategy().getWindowFn().windowCoder(); // Since this path might be used by streaming runners processing triggers, it's important // to assign shard numbers here so that they are deterministic. The ASSIGN_IN_FINALIZE // strategy works by sorting all FileResult objects and assigning them numbers, which is not // guaranteed to work well when processing triggers - if the finalize step retries it might // see a different Iterable of FileResult objects, and it will assign different shard numbers. - results = sharded.apply("WriteShardedBundles", - ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_WHEN_WRITING))); + results = + sharded.apply( + "WriteShardedBundles", + ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_WHEN_WRITING))); } - results.setCoder(FileResultCoder.of(shardedWindowCoder)); + results.setCoder(FileResultCoder.of(shardedWindowCoder, destinationCoder)); if (windowedWrites) { // When processing streaming windowed writes, results will arrive multiple times. This @@ -609,26 +722,31 @@ private PDone createWrite(PCollection input) { // as new data arriving into a side input does not trigger the listening DoFn. Instead // we aggregate the result set using a singleton GroupByKey, so the DoFn will be triggered // whenever new data arrives. 
- PCollection> keyedResults = - results.apply("AttachSingletonKey", WithKeys.of((Void) null)); - keyedResults.setCoder(KvCoder.of(VoidCoder.of(), - FileResultCoder.of(shardedWindowCoder))); + PCollection>> keyedResults = + results.apply( + "AttachSingletonKey", WithKeys.>of((Void) null)); + keyedResults.setCoder( + KvCoder.of(VoidCoder.of(), FileResultCoder.of(shardedWindowCoder, destinationCoder))); // Is the continuation trigger sufficient? keyedResults - .apply("FinalizeGroupByKey", GroupByKey.create()) - .apply("Finalize", ParDo.of(new DoFn>, Integer>() { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - LOG.info("Finalizing write operation {}.", writeOperation); - List results = Lists.newArrayList(c.element().getValue()); - writeOperation.finalize(results); - LOG.debug("Done finalizing write operation"); - } - })); + .apply("FinalizeGroupByKey", GroupByKey.>create()) + .apply( + "Finalize", + ParDo.of( + new DoFn>>, Integer>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + LOG.info("Finalizing write operation {}.", writeOperation); + List> results = + Lists.newArrayList(c.element().getValue()); + writeOperation.finalize(results); + LOG.debug("Done finalizing write operation"); + } + })); } else { - final PCollectionView> resultsView = - results.apply(View.asIterable()); + final PCollectionView>> resultsView = + results.apply(View.>asIterable()); ImmutableList.Builder> sideInputs = ImmutableList.>builder().add(resultsView); if (numShardsView != null) { @@ -644,41 +762,53 @@ public void processElement(ProcessContext c) throws Exception { // set numShards, then all shards will be written out as empty files. For this reason we // use a side input here. PCollection singletonCollection = p.apply(Create.of((Void) null)); - singletonCollection - .apply("Finalize", ParDo.of(new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - LOG.info("Finalizing write operation {}.", writeOperation); - List results = Lists.newArrayList(c.sideInput(resultsView)); - LOG.debug("Side input initialized to finalize write operation {}.", writeOperation); - - // We must always output at least 1 shard, and honor user-specified numShards if - // set. 
- int minShardsNeeded; - if (numShardsView != null) { - minShardsNeeded = c.sideInput(numShardsView); - } else if (numShardsProvider != null) { - minShardsNeeded = numShardsProvider.get(); - } else { - minShardsNeeded = 1; - } - int extraShardsNeeded = minShardsNeeded - results.size(); - if (extraShardsNeeded > 0) { - LOG.info( - "Creating {} empty output shards in addition to {} written for a total of {}.", - extraShardsNeeded, results.size(), minShardsNeeded); - for (int i = 0; i < extraShardsNeeded; ++i) { - Writer writer = writeOperation.createWriter(); - writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM); - FileResult emptyWrite = writer.close(); - results.add(emptyWrite); - } - LOG.debug("Done creating extra shards."); - } - writeOperation.finalize(results); - LOG.debug("Done finalizing write operation {}", writeOperation); - } - }).withSideInputs(sideInputs.build())); + singletonCollection.apply( + "Finalize", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + LOG.info("Finalizing write operation {}.", writeOperation); + List> results = + Lists.newArrayList(c.sideInput(resultsView)); + LOG.debug( + "Side input initialized to finalize write operation {}.", writeOperation); + + // We must always output at least 1 shard, and honor user-specified numShards + // if + // set. + int minShardsNeeded; + if (numShardsView != null) { + minShardsNeeded = c.sideInput(numShardsView); + } else if (numShardsProvider != null) { + minShardsNeeded = numShardsProvider.get(); + } else { + minShardsNeeded = 1; + } + int extraShardsNeeded = minShardsNeeded - results.size(); + if (extraShardsNeeded > 0) { + LOG.info( + "Creating {} empty output shards in addition to {} written " + + "for a total of {}.", + extraShardsNeeded, + results.size(), + minShardsNeeded); + for (int i = 0; i < extraShardsNeeded; ++i) { + Writer writer = writeOperation.createWriter(); + writer.openUnwindowed( + UUID.randomUUID().toString(), + UNKNOWN_SHARDNUM, + sink.getDynamicDestinations().getDefaultDestination()); + FileResult emptyWrite = writer.close(); + results.add(emptyWrite); + } + LOG.debug("Done creating extra shards."); + } + writeOperation.finalize(results); + LOG.debug("Done finalizing write operation {}", writeOperation); + } + }) + .withSideInputs(sideInputs.build())); } return PDone.in(input.getPipeline()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableFunctions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableFunctions.java new file mode 100644 index 0000000000000..d057d81eb45d2 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableFunctions.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.transforms; + +/** Useful {@link SerializableFunction} overrides. */ +public class SerializableFunctions { + private static class Identity implements SerializableFunction { + @Override + public T apply(T input) { + return input; + } + } + + private static class Constant implements SerializableFunction { + OutT value; + + Constant(OutT value) { + this.value = value; + } + + @Override + public OutT apply(InT input) { + return value; + } + } + + public static SerializableFunction identity() { + return new Identity<>(); + } + + public static SerializableFunction constant(OutT value) { + return new Constant<>(value); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/ShardedKey.java similarity index 90% rename from sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKey.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/values/ShardedKey.java index c2b739f999e6a..e56af134fefa5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/ShardedKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/ShardedKey.java @@ -16,15 +16,13 @@ * limitations under the License. */ -package org.apache.beam.sdk.io.gcp.bigquery; +package org.apache.beam.sdk.values; import java.io.Serializable; import java.util.Objects; -/** - * A key and a shard number. - */ -class ShardedKey implements Serializable { +/** A key and a shard number. */ +public class ShardedKey implements Serializable { private static final long serialVersionUID = 1L; private final K key; private final int shardNumber; @@ -33,7 +31,7 @@ public static ShardedKey of(K key, int shardNumber) { return new ShardedKey<>(key, shardNumber); } - ShardedKey(K key, int shardNumber) { + private ShardedKey(K key, int shardNumber) { this.key = key; this.shardNumber = shardNumber; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java index 6d01d3237fc38..260e47a25a5cb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java @@ -54,10 +54,11 @@ import org.apache.avro.reflect.ReflectDatumReader; import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.coders.DefaultCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -276,37 +277,42 @@ public void testAvroIOWriteAndReadSchemaUpgrade() throws Throwable { } private static class WindowedFilenamePolicy extends FilenamePolicy { - final String outputFilePrefix; + final ResourceId outputFilePrefix; - WindowedFilenamePolicy(String outputFilePrefix) { + WindowedFilenamePolicy(ResourceId outputFilePrefix) { this.outputFilePrefix = outputFilePrefix; } @Override - public ResourceId 
windowedFilename( - ResourceId outputDirectory, WindowedContext input, String extension) { - String filename = String.format( - "%s-%s-%s-of-%s-pane-%s%s%s", - outputFilePrefix, - input.getWindow(), - input.getShardNumber(), - input.getNumShards() - 1, - input.getPaneInfo().getIndex(), - input.getPaneInfo().isLast() ? "-final" : "", - extension); - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + public ResourceId windowedFilename(WindowedContext input, OutputFileHints outputFileHints) { + String filenamePrefix = + outputFilePrefix.isDirectory() ? "" : firstNonNull(outputFilePrefix.getFilename(), ""); + + String filename = + String.format( + "%s-%s-%s-of-%s-pane-%s%s%s", + filenamePrefix, + input.getWindow(), + input.getShardNumber(), + input.getNumShards() - 1, + input.getPaneInfo().getIndex(), + input.getPaneInfo().isLast() ? "-final" : "", + outputFileHints.getSuggestedFilenameSuffix()); + return outputFilePrefix + .getCurrentDirectory() + .resolve(filename, StandardResolveOptions.RESOLVE_FILE); } @Override - public ResourceId unwindowedFilename( - ResourceId outputDirectory, Context input, String extension) { + public ResourceId unwindowedFilename(Context input, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Expecting windowed outputs only"); } @Override public void populateDisplayData(DisplayData.Builder builder) { - builder.add(DisplayData.item("fileNamePrefix", outputFilePrefix) - .withLabel("File Name Prefix")); + builder.add( + DisplayData.item("fileNamePrefix", outputFilePrefix.toString()) + .withLabel("File Name Prefix")); } } @@ -359,15 +365,18 @@ public void testWindowedAvroIOWrite() throws Throwable { Arrays.copyOfRange(secondWindowArray, 1, secondWindowArray.length)) .advanceWatermarkToInfinity(); - FilenamePolicy policy = new WindowedFilenamePolicy(baseFilename); + FilenamePolicy policy = + new WindowedFilenamePolicy(FileBasedSink.convertToFileResourceIfPossible(baseFilename)); windowedAvroWritePipeline .apply(values) .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))) - .apply(AvroIO.write(GenericClass.class) - .to(baseFilename) - .withFilenamePolicy(policy) - .withWindowedWrites() - .withNumShards(2)); + .apply( + AvroIO.write(GenericClass.class) + .to(policy) + .withTempDirectory( + StaticValueProvider.of(FileSystems.matchNewResource(baseDir.toString(), true))) + .withWindowedWrites() + .withNumShards(2)); windowedAvroWritePipeline.run(); // Validate that the data written matches the expected elements in the expected order @@ -494,13 +503,14 @@ public static void assertTestOutputs( expectedFiles.add( new File( DefaultFilenamePolicy.constructName( - outputFilePrefix, - shardNameTemplate, - "" /* no suffix */, - i, - numShards, - null, - null))); + FileBasedSink.convertToFileResourceIfPossible(outputFilePrefix), + shardNameTemplate, + "" /* no suffix */, + i, + numShards, + null, + null) + .toString())); } List actualElements = new ArrayList<>(); @@ -572,15 +582,4 @@ public void testWriteDisplayData() { assertThat(displayData, hasDisplayItem("numShards", 100)); assertThat(displayData, hasDisplayItem("codec", CodecFactory.snappyCodec().toString())); } - - @Test - public void testWindowedWriteRequiresFilenamePolicy() { - PCollection emptyInput = p.apply(Create.empty(StringUtf8Coder.of())); - AvroIO.Write write = AvroIO.write(String.class).to("/tmp/some/file").withWindowedWrites(); - - expectedException.expect(IllegalStateException.class); - expectedException.expectMessage( - "When using windowed 
writes, a filename policy must be set via withFilenamePolicy()"); - emptyInput.apply(write); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DefaultFilenamePolicyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DefaultFilenamePolicyTest.java index 217420cac8817..9dc6d33c3409b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DefaultFilenamePolicyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DefaultFilenamePolicyTest.java @@ -17,9 +17,9 @@ */ package org.apache.beam.sdk.io; -import static org.apache.beam.sdk.io.DefaultFilenamePolicy.constructName; import static org.junit.Assert.assertEquals; +import org.apache.beam.sdk.io.fs.ResourceId; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -30,69 +30,108 @@ @RunWith(JUnit4.class) public class DefaultFilenamePolicyTest { + private static String constructName( + String baseFilename, + String shardTemplate, + String suffix, + int shardNum, + int numShards, + String paneStr, + String windowStr) { + ResourceId constructed = + DefaultFilenamePolicy.constructName( + FileSystems.matchNewResource(baseFilename, false), + shardTemplate, + suffix, + shardNum, + numShards, + paneStr, + windowStr); + return constructed.toString(); + } + @Test public void testConstructName() { - assertEquals("output-001-of-123.txt", - constructName("output", "-SSS-of-NNN", ".txt", 1, 123, null, null)); + assertEquals( + "/path/to/output-001-of-123.txt", + constructName("/path/to/output", "-SSS-of-NNN", ".txt", 1, 123, null, null)); - assertEquals("out.txt/part-00042", - constructName("out.txt", "/part-SSSSS", "", 42, 100, null, null)); + assertEquals( + "/path/to/out.txt/part-00042", + constructName("/path/to/out.txt", "/part-SSSSS", "", 42, 100, null, null)); - assertEquals("out.txt", - constructName("ou", "t.t", "xt", 1, 1, null, null)); + assertEquals("/path/to/out.txt", constructName("/path/to/ou", "t.t", "xt", 1, 1, null, null)); - assertEquals("out0102shard.txt", - constructName("out", "SSNNshard", ".txt", 1, 2, null, null)); + assertEquals( + "/path/to/out0102shard.txt", + constructName("/path/to/out", "SSNNshard", ".txt", 1, 2, null, null)); - assertEquals("out-2/1.part-1-of-2.txt", - constructName("out", "-N/S.part-S-of-N", ".txt", 1, 2, null, null)); + assertEquals( + "/path/to/out-2/1.part-1-of-2.txt", + constructName("/path/to/out", "-N/S.part-S-of-N", ".txt", 1, 2, null, null)); } @Test public void testConstructNameWithLargeShardCount() { - assertEquals("out-100-of-5000.txt", - constructName("out", "-SS-of-NN", ".txt", 100, 5000, null, null)); + assertEquals( + "/out-100-of-5000.txt", constructName("/out", "-SS-of-NN", ".txt", 100, 5000, null, null)); } @Test public void testConstructWindowedName() { - assertEquals("output-001-of-123.txt", - constructName("output", "-SSS-of-NNN", ".txt", 1, 123, null, null)); - - assertEquals("output-001-of-123-PPP-W.txt", - constructName("output", "-SSS-of-NNN-PPP-W", ".txt", 1, 123, null, null)); - - assertEquals("out.txt/part-00042-myPaneStr-myWindowStr", - constructName("out.txt", "/part-SSSSS-P-W", "", 42, 100, "myPaneStr", - "myWindowStr")); - - assertEquals("out.txt", constructName("ou", "t.t", "xt", 1, 1, "myPaneStr2", - "anotherWindowStr")); - - assertEquals("out0102shard-oneMoreWindowStr-anotherPaneStr.txt", - constructName("out", "SSNNshard-W-P", ".txt", 1, 2, "anotherPaneStr", - "oneMoreWindowStr")); - - assertEquals("out-2/1.part-1-of-2-slidingWindow1-myPaneStr3-windowslidingWindow1-" - + 
"panemyPaneStr3.txt", - constructName("out", "-N/S.part-S-of-N-W-P-windowW-paneP", ".txt", 1, 2, "myPaneStr3", - "slidingWindow1")); + assertEquals( + "/path/to/output-001-of-123.txt", + constructName("/path/to/output", "-SSS-of-NNN", ".txt", 1, 123, null, null)); + + assertEquals( + "/path/to/output-001-of-123-PPP-W.txt", + constructName("/path/to/output", "-SSS-of-NNN-PPP-W", ".txt", 1, 123, null, null)); + + assertEquals( + "/path/to/out" + ".txt/part-00042-myPaneStr-myWindowStr", + constructName( + "/path/to/out.txt", "/part-SSSSS-P-W", "", 42, 100, "myPaneStr", "myWindowStr")); + + assertEquals( + "/path/to/out.txt", + constructName("/path/to/ou", "t.t", "xt", 1, 1, "myPaneStr2", "anotherWindowStr")); + + assertEquals( + "/path/to/out0102shard-oneMoreWindowStr-anotherPaneStr.txt", + constructName( + "/path/to/out", "SSNNshard-W-P", ".txt", 1, 2, "anotherPaneStr", "oneMoreWindowStr")); + + assertEquals( + "/out-2/1.part-1-of-2-slidingWindow1-myPaneStr3-windowslidingWindow1-" + + "panemyPaneStr3.txt", + constructName( + "/out", + "-N/S.part-S-of-N-W-P-windowW-paneP", + ".txt", + 1, + 2, + "myPaneStr3", + "slidingWindow1")); // test first/last pane - assertEquals("out.txt/part-00042-myWindowStr-pane-11-true-false", - constructName("out.txt", "/part-SSSSS-W-P", "", 42, 100, "pane-11-true-false", - "myWindowStr")); - - assertEquals("out.txt", constructName("ou", "t.t", "xt", 1, 1, "pane", - "anotherWindowStr")); - - assertEquals("out0102shard-oneMoreWindowStr-pane--1-false-false-pane--1-false-false.txt", - constructName("out", "SSNNshard-W-P-P", ".txt", 1, 2, "pane--1-false-false", - "oneMoreWindowStr")); - - assertEquals("out-2/1.part-1-of-2-sWindow1-winsWindow1-ppaneL.txt", - constructName("out", - "-N/S.part-S-of-N-W-winW-pP", ".txt", 1, 2, "paneL", "sWindow1")); + assertEquals( + "/out.txt/part-00042-myWindowStr-pane-11-true-false", + constructName( + "/out.txt", "/part-SSSSS-W-P", "", 42, 100, "pane-11-true-false", "myWindowStr")); + + assertEquals( + "/path/to/out.txt", + constructName("/path/to/ou", "t.t", "xt", 1, 1, "pane", "anotherWindowStr")); + + assertEquals( + "/out0102shard-oneMoreWindowStr-pane--1-false-false-pane--1-false-false.txt", + constructName( + "/out", "SSNNshard-W-P-P", ".txt", 1, 2, "pane--1-false-false", "oneMoreWindowStr")); + + assertEquals( + "/path/to/out-2/1.part-1-of-2-sWindow1-winsWindow1-ppaneL.txt", + constructName( + "/path/to/out", "-N/S.part-S-of-N-W-winW-pP", ".txt", 1, 2, "paneL", "sWindow1")); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DrunkWritableByteChannelFactory.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DrunkWritableByteChannelFactory.java index 6615a2e9d3333..a7644b6365258 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DrunkWritableByteChannelFactory.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/DrunkWritableByteChannelFactory.java @@ -39,7 +39,7 @@ public String getMimeType() { } @Override - public String getFilenameSuffix() { + public String getSuggestedFilenameSuffix() { return ".drunk"; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java index caad75989cbd4..755bb598524d6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java @@ -103,7 +103,7 @@ public void testWriter() throws Exception { SimpleSink.SimpleWriter writer = 
buildWriteOperationWithTempDir(getBaseTempDirectory()).createWriter(); - writer.openUnwindowed(testUid, -1); + writer.openUnwindowed(testUid, -1, null); for (String value : values) { writer.write(value); } @@ -198,23 +198,27 @@ private void runFinalize(SimpleSink.SimpleWriteOperation writeOp, List tem throws Exception { int numFiles = temporaryFiles.size(); - List fileResults = new ArrayList<>(); + List> fileResults = new ArrayList<>(); // Create temporary output bundles and output File objects. for (int i = 0; i < numFiles; i++) { fileResults.add( - new FileResult( + new FileResult( LocalResources.fromFile(temporaryFiles.get(i), false), WriteFiles.UNKNOWN_SHARDNUM, null, + null, null)); } writeOp.finalize(fileResults); - ResourceId outputDirectory = writeOp.getSink().getBaseOutputDirectoryProvider().get(); for (int i = 0; i < numFiles; i++) { - ResourceId outputFilename = writeOp.getSink().getFilenamePolicy() - .unwindowedFilename(outputDirectory, new Context(i, numFiles), ""); + ResourceId outputFilename = + writeOp + .getSink() + .getDynamicDestinations() + .getFilenamePolicy(null) + .unwindowedFilename(new Context(i, numFiles), CompressionType.UNCOMPRESSED); assertTrue(new File(outputFilename.toString()).exists()); assertFalse(temporaryFiles.get(i).exists()); } @@ -231,11 +235,12 @@ private void runFinalize(SimpleSink.SimpleWriteOperation writeOp, List tem private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) throws Exception { String prefix = "file"; - SimpleSink sink = - new SimpleSink(getBaseOutputDirectory(), prefix, "", ""); + SimpleSink sink = + SimpleSink.makeSimpleSink( + getBaseOutputDirectory(), prefix, "", "", CompressionType.UNCOMPRESSED); - WriteOperation writeOp = - new SimpleSink.SimpleWriteOperation(sink, tempDirectory); + WriteOperation writeOp = + new SimpleSink.SimpleWriteOperation<>(sink, tempDirectory); List temporaryFiles = new ArrayList<>(); List outputFiles = new ArrayList<>(); @@ -272,8 +277,6 @@ private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) @Test public void testCopyToOutputFiles() throws Exception { SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(); - ResourceId outputDirectory = writeOp.getSink().getBaseOutputDirectoryProvider().get(); - List inputFilenames = Arrays.asList("input-1", "input-2", "input-3"); List inputContents = Arrays.asList("1", "2", "3"); List expectedOutputFilenames = Arrays.asList( @@ -292,9 +295,14 @@ public void testCopyToOutputFiles() throws Exception { File inputTmpFile = tmpFolder.newFile(inputFilenames.get(i)); List lines = Collections.singletonList(inputContents.get(i)); writeFile(lines, inputTmpFile); - inputFilePaths.put(LocalResources.fromFile(inputTmpFile, false), - writeOp.getSink().getFilenamePolicy() - .unwindowedFilename(outputDirectory, new Context(i, inputFilenames.size()), "")); + inputFilePaths.put( + LocalResources.fromFile(inputTmpFile, false), + writeOp + .getSink() + .getDynamicDestinations() + .getFilenamePolicy(null) + .unwindowedFilename( + new Context(i, inputFilenames.size()), CompressionType.UNCOMPRESSED)); } // Copy input files to output files. 
@@ -311,7 +319,8 @@ public List generateDestinationFilenames( ResourceId outputDirectory, FilenamePolicy policy, int numFiles) { List filenames = new ArrayList<>(); for (int i = 0; i < numFiles; i++) { - filenames.add(policy.unwindowedFilename(outputDirectory, new Context(i, numFiles), "")); + filenames.add( + policy.unwindowedFilename(new Context(i, numFiles), CompressionType.UNCOMPRESSED)); } return filenames; } @@ -326,8 +335,10 @@ public void testGenerateOutputFilenames() { List actual; ResourceId root = getBaseOutputDirectory(); - SimpleSink sink = new SimpleSink(root, "file", ".SSSSS.of.NNNNN", ".test"); - FilenamePolicy policy = sink.getFilenamePolicy(); + SimpleSink sink = + SimpleSink.makeSimpleSink( + root, "file", ".SSSSS.of.NNNNN", ".test", CompressionType.UNCOMPRESSED); + FilenamePolicy policy = sink.getDynamicDestinations().getFilenamePolicy(null); expected = Arrays.asList( root.resolve("file.00000.of.00003.test", StandardResolveOptions.RESOLVE_FILE), @@ -352,8 +363,9 @@ public void testGenerateOutputFilenames() { @Test public void testCollidingOutputFilenames() throws IOException { ResourceId root = getBaseOutputDirectory(); - SimpleSink sink = new SimpleSink(root, "file", "-NN", "test"); - SimpleSink.SimpleWriteOperation writeOp = new SimpleSink.SimpleWriteOperation(sink); + SimpleSink sink = + SimpleSink.makeSimpleSink(root, "file", "-NN", "test", CompressionType.UNCOMPRESSED); + SimpleSink.SimpleWriteOperation writeOp = new SimpleSink.SimpleWriteOperation<>(sink); ResourceId temp1 = root.resolve("temp1", StandardResolveOptions.RESOLVE_FILE); ResourceId temp2 = root.resolve("temp2", StandardResolveOptions.RESOLVE_FILE); @@ -361,11 +373,11 @@ public void testCollidingOutputFilenames() throws IOException { ResourceId output = root.resolve("file-03.test", StandardResolveOptions.RESOLVE_FILE); // More than one shard does. 
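    // All three temporary results below report shard 1, so they map to the same final
    // output filename and buildOutputFilenames is expected to fail.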
try { - Iterable results = + Iterable> results = Lists.newArrayList( - new FileResult(temp1, 1, null, null), - new FileResult(temp2, 1, null, null), - new FileResult(temp3, 1, null, null)); + new FileResult(temp1, 1, null, null, null), + new FileResult(temp2, 1, null, null, null), + new FileResult(temp3, 1, null, null, null)); writeOp.buildOutputFilenames(results); fail("Should have failed."); } catch (IllegalStateException exn) { @@ -379,8 +391,10 @@ public void testGenerateOutputFilenamesWithoutExtension() { List expected; List actual; ResourceId root = getBaseOutputDirectory(); - SimpleSink sink = new SimpleSink(root, "file", "-SSSSS-of-NNNNN", ""); - FilenamePolicy policy = sink.getFilenamePolicy(); + SimpleSink sink = + SimpleSink.makeSimpleSink( + root, "file", "-SSSSS-of-NNNNN", "", CompressionType.UNCOMPRESSED); + FilenamePolicy policy = sink.getDynamicDestinations().getFilenamePolicy(null); expected = Arrays.asList( root.resolve("file-00000-of-00003", StandardResolveOptions.RESOLVE_FILE), @@ -486,10 +500,11 @@ private File writeValuesWithWritableByteChannelFactory(final WritableByteChannel public void testFileBasedWriterWithWritableByteChannelFactory() throws Exception { final String testUid = "testId"; ResourceId root = getBaseOutputDirectory(); - WriteOperation writeOp = - new SimpleSink(root, "file", "-SS-of-NN", "txt", new DrunkWritableByteChannelFactory()) + WriteOperation writeOp = + SimpleSink.makeSimpleSink( + root, "file", "-SS-of-NN", "txt", new DrunkWritableByteChannelFactory()) .createWriteOperation(); - final Writer writer = writeOp.createWriter(); + final Writer writer = writeOp.createWriter(); final ResourceId expectedFile = writeOp.tempDirectory.get().resolve(testUid, StandardResolveOptions.RESOLVE_FILE); @@ -503,7 +518,7 @@ public void testFileBasedWriterWithWritableByteChannelFactory() throws Exception expected.add("footer"); expected.add("footer"); - writer.openUnwindowed(testUid, -1); + writer.openUnwindowed(testUid, -1, null); writer.write("a"); writer.write("b"); final FileResult result = writer.close(); @@ -513,20 +528,20 @@ public void testFileBasedWriterWithWritableByteChannelFactory() throws Exception } /** Build a SimpleSink with default options. */ - private SimpleSink buildSink() { - return new SimpleSink(getBaseOutputDirectory(), "file", "-SS-of-NN", ".test"); + private SimpleSink buildSink() { + return SimpleSink.makeSimpleSink( + getBaseOutputDirectory(), "file", "-SS-of-NN", ".test", CompressionType.UNCOMPRESSED); } - /** - * Build a SimpleWriteOperation with default options and the given temporary directory. - */ - private SimpleSink.SimpleWriteOperation buildWriteOperationWithTempDir(ResourceId tempDirectory) { - SimpleSink sink = buildSink(); - return new SimpleSink.SimpleWriteOperation(sink, tempDirectory); + /** Build a SimpleWriteOperation with default options and the given temporary directory. */ + private SimpleSink.SimpleWriteOperation buildWriteOperationWithTempDir( + ResourceId tempDirectory) { + SimpleSink sink = buildSink(); + return new SimpleSink.SimpleWriteOperation<>(sink, tempDirectory); } /** Build a write operation with the default options for it and its parent sink. 
*/ - private SimpleSink.SimpleWriteOperation buildWriteOperation() { + private SimpleSink.SimpleWriteOperation buildWriteOperation() { return buildSink().createWriteOperation(); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java index bdf37f635ef5f..9196178104dbd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java @@ -19,37 +19,55 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; +import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.util.MimeTypes; /** - * A simple {@link FileBasedSink} that writes {@link String} values as lines with - * header and footer. + * A simple {@link FileBasedSink} that writes {@link String} values as lines with header and footer. */ -class SimpleSink extends FileBasedSink { - public SimpleSink(ResourceId baseOutputDirectory, String prefix, String template, String suffix) { - this(baseOutputDirectory, prefix, template, suffix, CompressionType.UNCOMPRESSED); +class SimpleSink extends FileBasedSink { + public SimpleSink( + ResourceId tempDirectory, + DynamicDestinations dynamicDestinations, + WritableByteChannelFactory writableByteChannelFactory) { + super(StaticValueProvider.of(tempDirectory), dynamicDestinations, writableByteChannelFactory); } - public SimpleSink(ResourceId baseOutputDirectory, String prefix, String template, String suffix, - WritableByteChannelFactory writableByteChannelFactory) { - super( - StaticValueProvider.of(baseOutputDirectory), - new DefaultFilenamePolicy(StaticValueProvider.of(prefix), template, suffix), - writableByteChannelFactory); + public static SimpleSink makeSimpleSink( + ResourceId tempDirectory, FilenamePolicy filenamePolicy) { + return new SimpleSink<>( + tempDirectory, + DynamicFileDestinations.constant(filenamePolicy), + CompressionType.UNCOMPRESSED); } - public SimpleSink(ResourceId baseOutputDirectory, FilenamePolicy filenamePolicy) { - super(StaticValueProvider.of(baseOutputDirectory), filenamePolicy); + public static SimpleSink makeSimpleSink( + ResourceId baseDirectory, + String prefix, + String shardTemplate, + String suffix, + WritableByteChannelFactory writableByteChannelFactory) { + DynamicDestinations dynamicDestinations = + DynamicFileDestinations.constant( + DefaultFilenamePolicy.fromParams( + new Params() + .withBaseFilename( + baseDirectory.resolve(prefix, StandardResolveOptions.RESOLVE_FILE)) + .withShardTemplate(shardTemplate) + .withSuffix(suffix))); + return new SimpleSink<>(baseDirectory, dynamicDestinations, writableByteChannelFactory); } @Override - public SimpleWriteOperation createWriteOperation() { - return new SimpleWriteOperation(this); + public SimpleWriteOperation createWriteOperation() { + return new SimpleWriteOperation<>(this); } - static final class SimpleWriteOperation extends WriteOperation { + static final class SimpleWriteOperation + extends WriteOperation { public SimpleWriteOperation(SimpleSink sink, ResourceId tempOutputDirectory) { super(sink, tempOutputDirectory); } @@ -59,12 +77,12 @@ public SimpleWriteOperation(SimpleSink sink) { } @Override - public SimpleWriter createWriter() throws Exception { - return new SimpleWriter(this); + 
public SimpleWriter createWriter() throws Exception { + return new SimpleWriter<>(this); } } - static final class SimpleWriter extends Writer { + static final class SimpleWriter extends Writer { static final String HEADER = "header"; static final String FOOTER = "footer"; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java index 9468893b3d882..8797ff76c7942 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java @@ -42,7 +42,9 @@ import static org.junit.Assert.assertTrue; import com.google.common.base.Function; +import com.google.common.base.Functions; import com.google.common.base.Predicate; +import com.google.common.base.Predicates; import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -69,22 +71,31 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.DefaultCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.FileBasedSink.WritableByteChannelFactory; import org.apache.beam.sdk.io.TextIO.CompressionType; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; +import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator; import org.apache.beam.sdk.util.CoderUtils; @@ -205,7 +216,7 @@ public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOEx }); } - private void runTestRead(String[] expected) throws Exception { + private void runTestRead(String[] expected) throws Exception { File tmpFile = Files.createTempFile(tempFolder, "file", "txt").toFile(); String filename = tmpFile.getPath(); @@ -274,6 +285,213 @@ public void testPrimitiveReadDisplayData() { displayData, hasItem(hasDisplayItem(hasValue(startsWith("foobar"))))); } + static class TestDynamicDestinations extends DynamicDestinations { + ResourceId baseDir; + + TestDynamicDestinations(ResourceId baseDir) { + this.baseDir = baseDir; + } + + @Override + public String getDestination(String element) { + // Destination is based on first character of string. 
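+      // e.g. "aaaa" and "aaab" both map to destination "a", which getFilenamePolicy below
+      // turns into an output file named file_a.txt.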
+ return element.substring(0, 1); + } + + @Override + public String getDefaultDestination() { + return ""; + } + + @Nullable + @Override + public Coder getDestinationCoder() { + return StringUtf8Coder.of(); + } + + @Override + public FilenamePolicy getFilenamePolicy(String destination) { + return DefaultFilenamePolicy.fromStandardParameters( + StaticValueProvider.of( + baseDir.resolve("file_" + destination + ".txt", StandardResolveOptions.RESOLVE_FILE)), + null, + null, + false); + } + } + + class StartsWith implements Predicate { + String prefix; + + StartsWith(String prefix) { + this.prefix = prefix; + } + + @Override + public boolean apply(@Nullable String input) { + return input.startsWith(prefix); + } + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinations() throws Exception { + ResourceId baseDir = + FileSystems.matchNewResource( + Files.createTempDirectory(tempFolder, "testDynamicDestinations").toString(), true); + + List elements = Lists.newArrayList("aaaa", "aaab", "baaa", "baab", "caaa", "caab"); + PCollection input = p.apply(Create.of(elements).withCoder(StringUtf8Coder.of())); + input.apply( + TextIO.write() + .to(new TestDynamicDestinations(baseDir)) + .withTempDirectory(FileSystems.matchNewResource(baseDir.toString(), true))); + p.run(); + + assertOutputFiles( + Iterables.toArray(Iterables.filter(elements, new StartsWith("a")), String.class), + null, + null, + 0, + baseDir.resolve("file_a.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + assertOutputFiles( + Iterables.toArray(Iterables.filter(elements, new StartsWith("b")), String.class), + null, + null, + 0, + baseDir.resolve("file_b.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + assertOutputFiles( + Iterables.toArray(Iterables.filter(elements, new StartsWith("c")), String.class), + null, + null, + 0, + baseDir.resolve("file_c.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + } + + @DefaultCoder(AvroCoder.class) + private static class UserWriteType { + String destination; + String metadata; + + UserWriteType() { + this.destination = ""; + this.metadata = ""; + } + + UserWriteType(String destination, String metadata) { + this.destination = destination; + this.metadata = metadata; + } + + @Override + public String toString() { + return String.format("destination: %s metadata : %s", destination, metadata); + } + } + + private static class SerializeUserWrite implements SerializableFunction { + @Override + public String apply(UserWriteType input) { + return input.toString(); + } + } + + private static class UserWriteDestination implements SerializableFunction { + private ResourceId baseDir; + + UserWriteDestination(ResourceId baseDir) { + this.baseDir = baseDir; + } + + @Override + public Params apply(UserWriteType input) { + return new Params() + .withBaseFilename( + baseDir.resolve( + "file_" + input.destination.substring(0, 1) + ".txt", + StandardResolveOptions.RESOLVE_FILE)); + } + } + + private static class ExtractWriteDestination implements Function { + @Override + public String apply(@Nullable UserWriteType input) { + return input.destination; + } + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDefaultFilenamePolicy() throws Exception { + ResourceId baseDir = + FileSystems.matchNewResource( + Files.createTempDirectory(tempFolder, "testDynamicDestinations").toString(), true); + + List elements = + 
Lists.newArrayList( + new UserWriteType("aaaa", "first"), + new UserWriteType("aaab", "second"), + new UserWriteType("baaa", "third"), + new UserWriteType("baab", "fourth"), + new UserWriteType("caaa", "fifth"), + new UserWriteType("caab", "sixth")); + PCollection input = p.apply(Create.of(elements)); + input.apply( + TextIO.writeCustomType(new SerializeUserWrite()) + .to(new UserWriteDestination(baseDir), new Params()) + .withTempDirectory(FileSystems.matchNewResource(baseDir.toString(), true))); + p.run(); + + String[] aElements = + Iterables.toArray( + Iterables.transform( + Iterables.filter( + elements, + Predicates.compose(new StartsWith("a"), new ExtractWriteDestination())), + Functions.toStringFunction()), + String.class); + String[] bElements = + Iterables.toArray( + Iterables.transform( + Iterables.filter( + elements, + Predicates.compose(new StartsWith("b"), new ExtractWriteDestination())), + Functions.toStringFunction()), + String.class); + String[] cElements = + Iterables.toArray( + Iterables.transform( + Iterables.filter( + elements, + Predicates.compose(new StartsWith("c"), new ExtractWriteDestination())), + Functions.toStringFunction()), + String.class); + assertOutputFiles( + aElements, + null, + null, + 0, + baseDir.resolve("file_a.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + assertOutputFiles( + bElements, + null, + null, + 0, + baseDir.resolve("file_b.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + assertOutputFiles( + cElements, + null, + null, + 0, + baseDir.resolve("file_c.txt", StandardResolveOptions.RESOLVE_FILE), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + } + private void runTestWrite(String[] elems) throws Exception { runTestWrite(elems, null, null, 1); } @@ -291,7 +509,8 @@ private void runTestWrite( String[] elems, String header, String footer, int numShards) throws Exception { String outputName = "file.txt"; Path baseDir = Files.createTempDirectory(tempFolder, "testwrite"); - String baseFilename = baseDir.resolve(outputName).toString(); + ResourceId baseFilename = + FileBasedSink.convertToFileResourceIfPossible(baseDir.resolve(outputName).toString()); PCollection input = p.apply(Create.of(Arrays.asList(elems)).withCoder(StringUtf8Coder.of())); @@ -311,8 +530,14 @@ private void runTestWrite( p.run(); - assertOutputFiles(elems, header, footer, numShards, baseDir, outputName, - firstNonNull(write.getShardTemplate(), + assertOutputFiles( + elems, + header, + footer, + numShards, + baseFilename, + firstNonNull( + write.inner.getShardTemplate(), DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE)); } @@ -321,13 +546,12 @@ public static void assertOutputFiles( final String header, final String footer, int numShards, - Path rootLocation, - String outputName, + ResourceId outputPrefix, String shardNameTemplate) throws Exception { List expectedFiles = new ArrayList<>(); if (numShards == 0) { - String pattern = rootLocation.toAbsolutePath().resolve(outputName + "*").toString(); + String pattern = outputPrefix.toString() + "*"; List matches = FileSystems.match(Collections.singletonList(pattern)); for (Metadata expectedFile : Iterables.getOnlyElement(matches).metadata()) { expectedFiles.add(new File(expectedFile.resourceId().toString())); @@ -336,9 +560,9 @@ public static void assertOutputFiles( for (int i = 0; i < numShards; i++) { expectedFiles.add( new File( - rootLocation.toString(), DefaultFilenamePolicy.constructName( - 
outputName, shardNameTemplate, "", i, numShards, null, null))); + outputPrefix, shardNameTemplate, "", i, numShards, null, null) + .toString())); } } @@ -456,14 +680,19 @@ public void testWriteWithHeaderAndFooter() throws Exception { public void testWriteWithWritableByteChannelFactory() throws Exception { Coder coder = StringUtf8Coder.of(); String outputName = "file.txt"; - Path baseDir = Files.createTempDirectory(tempFolder, "testwrite"); + ResourceId baseDir = + FileSystems.matchNewResource( + Files.createTempDirectory(tempFolder, "testwrite").toString(), true); PCollection input = p.apply(Create.of(Arrays.asList(LINES2_ARRAY)).withCoder(coder)); final WritableByteChannelFactory writableByteChannelFactory = new DrunkWritableByteChannelFactory(); - TextIO.Write write = TextIO.write().to(baseDir.resolve(outputName).toString()) - .withoutSharding().withWritableByteChannelFactory(writableByteChannelFactory); + TextIO.Write write = + TextIO.write() + .to(baseDir.resolve(outputName, StandardResolveOptions.RESOLVE_FILE).toString()) + .withoutSharding() + .withWritableByteChannelFactory(writableByteChannelFactory); DisplayData displayData = DisplayData.from(write); assertThat(displayData, hasDisplayItem("writableByteChannelFactory", "DRUNK")); @@ -476,8 +705,15 @@ public void testWriteWithWritableByteChannelFactory() throws Exception { drunkElems.add(elem); drunkElems.add(elem); } - assertOutputFiles(drunkElems.toArray(new String[0]), null, null, 1, baseDir, - outputName + writableByteChannelFactory.getFilenameSuffix(), write.getShardTemplate()); + assertOutputFiles( + drunkElems.toArray(new String[0]), + null, + null, + 1, + baseDir.resolve( + outputName + writableByteChannelFactory.getSuggestedFilenameSuffix(), + StandardResolveOptions.RESOLVE_FILE), + write.inner.getShardTemplate()); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java index e6a0dcf2c66fa..55f2a87205601 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io; +import static com.google.common.base.MoreObjects.firstNonNull; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -41,7 +42,11 @@ import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; +import org.apache.beam.sdk.io.FileBasedSink.CompressionType; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; import org.apache.beam.sdk.io.SimpleSink.SimpleWriter; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; @@ -58,16 +63,20 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.SimpleFunction; import 
org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.display.DisplayData.Builder; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.joda.time.Duration; import org.joda.time.format.DateTimeFormatter; @@ -164,7 +173,11 @@ private String getBaseOutputFilename() { public void testWrite() throws IOException { List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", "Intimidating pigeon", "Pedantic gull", "Frisky finch"); - runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); + runWrite( + inputs, + IDENTITY_MAP, + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } /** @@ -173,8 +186,11 @@ public void testWrite() throws IOException { @Test @Category(NeedsRunner.class) public void testEmptyWrite() throws IOException { - runWrite(Collections.emptyList(), IDENTITY_MAP, getBaseOutputFilename(), - WriteFiles.to(makeSimpleSink())); + runWrite( + Collections.emptyList(), + IDENTITY_MAP, + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); checkFileContents(getBaseOutputFilename(), Collections.emptyList(), Optional.of(1)); } @@ -190,7 +206,7 @@ public void testShardedWrite() throws IOException { Arrays.asList("one", "two", "three", "four", "five", "six"), IDENTITY_MAP, getBaseOutputFilename(), - WriteFiles.to(makeSimpleSink()).withNumShards(1)); + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity()).withNumShards(1)); } private ResourceId getBaseOutputDirectory() { @@ -198,9 +214,13 @@ private ResourceId getBaseOutputDirectory() { .resolve("output", StandardResolveOptions.RESOLVE_DIRECTORY); } - private SimpleSink makeSimpleSink() { - FilenamePolicy filenamePolicy = new PerWindowFiles("file", "simple"); - return new SimpleSink(getBaseOutputDirectory(), filenamePolicy); + + private SimpleSink makeSimpleSink() { + FilenamePolicy filenamePolicy = + new PerWindowFiles( + getBaseOutputDirectory().resolve("file", StandardResolveOptions.RESOLVE_FILE), + "simple"); + return SimpleSink.makeSimpleSink(getBaseOutputDirectory(), filenamePolicy); } @Test @@ -219,8 +239,10 @@ public void testCustomShardedWrite() throws IOException { timestamps.add(i + 1); } - SimpleSink sink = makeSimpleSink(); - WriteFiles write = WriteFiles.to(sink).withSharding(new LargestInt()); + SimpleSink sink = makeSimpleSink(); + WriteFiles write = + WriteFiles.to(sink, SerializableFunctions.identity()) + .withSharding(new LargestInt()); p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of())) .apply(IDENTITY_MAP) .apply(write); @@ -241,7 +263,8 @@ public void testExpandShardedWrite() throws IOException { Arrays.asList("one", "two", "three", "four", "five", "six"), IDENTITY_MAP, getBaseOutputFilename(), - WriteFiles.to(makeSimpleSink()).withNumShards(20)); + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity()) + .withNumShards(20)); } /** @@ -251,7 +274,11 @@ public void testExpandShardedWrite() throws IOException { @Category(NeedsRunner.class) 
public void testWriteWithEmptyPCollection() throws IOException { List inputs = new ArrayList<>(); - runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); + runWrite( + inputs, + IDENTITY_MAP, + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } /** @@ -263,8 +290,10 @@ public void testWriteWindowed() throws IOException { List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", "Intimidating pigeon", "Pedantic gull", "Frisky finch"); runWrite( - inputs, new WindowAndReshuffle<>(Window.into(FixedWindows.of(Duration.millis(2)))), - getBaseOutputFilename(), WriteFiles.to(makeSimpleSink())); + inputs, + new WindowAndReshuffle<>(Window.into(FixedWindows.of(Duration.millis(2)))), + getBaseOutputFilename(), + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } /** @@ -278,10 +307,9 @@ public void testWriteWithSessions() throws IOException { runWrite( inputs, - new WindowAndReshuffle<>( - Window.into(Sessions.withGapDuration(Duration.millis(1)))), + new WindowAndReshuffle<>(Window.into(Sessions.withGapDuration(Duration.millis(1)))), getBaseOutputFilename(), - WriteFiles.to(makeSimpleSink())); + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } @Test @@ -292,15 +320,19 @@ public void testWriteSpilling() throws IOException { inputs.add("mambo_number_" + i); } runWrite( - inputs, Window.into(FixedWindows.of(Duration.millis(2))), + inputs, + Window.into(FixedWindows.of(Duration.millis(2))), getBaseOutputFilename(), - WriteFiles.to(makeSimpleSink()).withMaxNumWritersPerBundle(2).withWindowedWrites()); + WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity()) + .withMaxNumWritersPerBundle(2) + .withWindowedWrites()); } public void testBuildWrite() { - SimpleSink sink = makeSimpleSink(); - WriteFiles write = WriteFiles.to(sink).withNumShards(3); - assertThat((SimpleSink) write.getSink(), is(sink)); + SimpleSink sink = makeSimpleSink(); + WriteFiles write = + WriteFiles.to(sink, SerializableFunctions.identity()).withNumShards(3); + assertThat((SimpleSink) write.getSink(), is(sink)); PTransform, PCollectionView> originalSharding = write.getSharding(); @@ -309,40 +341,183 @@ public void testBuildWrite() { assertThat(write.getNumShards().get(), equalTo(3)); assertThat(write.getSharding(), equalTo(originalSharding)); - WriteFiles write2 = write.withSharding(SHARDING_TRANSFORM); - assertThat((SimpleSink) write2.getSink(), is(sink)); + WriteFiles write2 = write.withSharding(SHARDING_TRANSFORM); + assertThat((SimpleSink) write2.getSink(), is(sink)); assertThat(write2.getSharding(), equalTo(SHARDING_TRANSFORM)); // original unchanged - WriteFiles writeUnsharded = write2.withRunnerDeterminedSharding(); + WriteFiles writeUnsharded = write2.withRunnerDeterminedSharding(); assertThat(writeUnsharded.getSharding(), nullValue()); assertThat(write.getSharding(), equalTo(originalSharding)); } @Test public void testDisplayData() { - SimpleSink sink = new SimpleSink(getBaseOutputDirectory(), "file", "-SS-of-NN", "") { - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.add(DisplayData.item("foo", "bar")); - } - }; - WriteFiles write = WriteFiles.to(sink); + DynamicDestinations dynamicDestinations = + DynamicFileDestinations.constant( + DefaultFilenamePolicy.fromParams( + new Params() + .withBaseFilename( + getBaseOutputDirectory() + .resolve("file", StandardResolveOptions.RESOLVE_FILE)) + .withShardTemplate("-SS-of-NN"))); + SimpleSink sink = + new 
SimpleSink( + getBaseOutputDirectory(), dynamicDestinations, CompressionType.UNCOMPRESSED) { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("foo", "bar")); + } + }; + WriteFiles write = + WriteFiles.to(sink, SerializableFunctions.identity()); + DisplayData displayData = DisplayData.from(write); assertThat(displayData, hasDisplayItem("sink", sink.getClass())); assertThat(displayData, includesDisplayDataFor("sink", sink)); } + @Test + @Category(NeedsRunner.class) + public void testUnboundedNeedsWindowed() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "Must use windowed writes when applying WriteFiles to an unbounded PCollection"); + + SimpleSink sink = makeSimpleSink(); + p.apply(Create.of("foo")) + .setIsBoundedInternal(IsBounded.UNBOUNDED) + .apply(WriteFiles.to(sink, SerializableFunctions.identity())); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testUnboundedNeedsSharding() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "When applying WriteFiles to an unbounded PCollection, " + + "must specify number of output shards explicitly"); + + SimpleSink sink = makeSimpleSink(); + p.apply(Create.of("foo")) + .setIsBoundedInternal(IsBounded.UNBOUNDED) + .apply(WriteFiles.to(sink, SerializableFunctions.identity()).withWindowedWrites()); + p.run(); + } + + // Test DynamicDestinations class. Expects user values to be string-encoded integers. + // Stores the integer mod 5 as the destination, and uses that in the file prefix. + static class TestDestinations extends DynamicDestinations { + private ResourceId baseOutputDirectory; + + TestDestinations(ResourceId baseOutputDirectory) { + this.baseOutputDirectory = baseOutputDirectory; + } + + @Override + public Integer getDestination(String element) { + return Integer.valueOf(element) % 5; + } + + @Override + public Integer getDefaultDestination() { + return 0; + } + + @Override + public FilenamePolicy getFilenamePolicy(Integer destination) { + return new PerWindowFiles( + baseOutputDirectory.resolve("file_" + destination, StandardResolveOptions.RESOLVE_FILE), + "simple"); + } + + @Override + public void populateDisplayData(Builder builder) { + super.populateDisplayData(builder); + } + } + + // Test format function. Prepend a string to each record before writing. + static class TestDynamicFormatFunction implements SerializableFunction { + @Override + public String apply(String input) { + return "record_" + input; + } + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsBounded() throws Exception { + testDynamicDestinationsHelper(true); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsUnbounded() throws Exception { + testDynamicDestinationsHelper(false); + } + + private void testDynamicDestinationsHelper(boolean bounded) throws IOException { + TestDestinations dynamicDestinations = new TestDestinations(getBaseOutputDirectory()); + SimpleSink sink = + new SimpleSink<>( + getBaseOutputDirectory(), dynamicDestinations, CompressionType.UNCOMPRESSED); + + // Flag to validate that the pipeline options are passed to the Sink. + WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class); + options.setTestFlag("test_value"); + Pipeline p = TestPipeline.create(options); + + List inputs = Lists.newArrayList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"); + // Prepare timestamps for the elements. 
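+    // (Each destination is element % 5, so inputs i and i + 5 end up in the same output file.)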
+ List timestamps = new ArrayList<>(); + for (long i = 0; i < inputs.size(); i++) { + timestamps.add(i + 1); + } + + WriteFiles writeFiles = + WriteFiles.to(sink, new TestDynamicFormatFunction()).withNumShards(1); + + PCollection input = p.apply(Create.timestamped(inputs, timestamps)); + if (!bounded) { + input.setIsBoundedInternal(IsBounded.UNBOUNDED); + input = input.apply(Window.into(FixedWindows.of(Duration.standardDays(1)))); + input.apply(writeFiles.withWindowedWrites()); + } else { + input.apply(writeFiles); + } + p.run(); + + for (int i = 0; i < 5; ++i) { + ResourceId base = + getBaseOutputDirectory().resolve("file_" + i, StandardResolveOptions.RESOLVE_FILE); + List expected = Lists.newArrayList("record_" + i, "record_" + (i + 5)); + checkFileContents(base.toString(), expected, Optional.of(1)); + } + } + @Test public void testShardedDisplayData() { - SimpleSink sink = new SimpleSink(getBaseOutputDirectory(), "file", "-SS-of-NN", "") { - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.add(DisplayData.item("foo", "bar")); - } - }; - WriteFiles write = WriteFiles.to(sink).withNumShards(1); + DynamicDestinations dynamicDestinations = + DynamicFileDestinations.constant( + DefaultFilenamePolicy.fromParams( + new Params() + .withBaseFilename( + getBaseOutputDirectory() + .resolve("file", StandardResolveOptions.RESOLVE_FILE)) + .withShardTemplate("-SS-of-NN"))); + SimpleSink sink = + new SimpleSink( + getBaseOutputDirectory(), dynamicDestinations, CompressionType.UNCOMPRESSED) { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("foo", "bar")); + } + }; + WriteFiles write = + WriteFiles.to(sink, SerializableFunctions.identity()).withNumShards(1); DisplayData displayData = DisplayData.from(write); assertThat(displayData, hasDisplayItem("sink", sink.getClass())); assertThat(displayData, includesDisplayDataFor("sink", sink)); @@ -351,14 +526,24 @@ public void populateDisplayData(DisplayData.Builder builder) { @Test public void testCustomShardStrategyDisplayData() { - SimpleSink sink = new SimpleSink(getBaseOutputDirectory(), "file", "-SS-of-NN", "") { - @Override - public void populateDisplayData(DisplayData.Builder builder) { - builder.add(DisplayData.item("foo", "bar")); - } - }; - WriteFiles write = - WriteFiles.to(sink) + DynamicDestinations dynamicDestinations = + DynamicFileDestinations.constant( + DefaultFilenamePolicy.fromParams( + new Params() + .withBaseFilename( + getBaseOutputDirectory() + .resolve("file", StandardResolveOptions.RESOLVE_FILE)) + .withShardTemplate("-SS-of-NN"))); + SimpleSink sink = + new SimpleSink( + getBaseOutputDirectory(), dynamicDestinations, CompressionType.UNCOMPRESSED) { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("foo", "bar")); + } + }; + WriteFiles write = + WriteFiles.to(sink, SerializableFunctions.identity()) .withSharding( new PTransform, PCollectionView>() { @Override @@ -383,59 +568,77 @@ public void populateDisplayData(DisplayData.Builder builder) { * PCollection are written to the sink. 
*/ private void runWrite( - List inputs, PTransform, PCollection> transform, - String baseName, WriteFiles write) throws IOException { + List inputs, + PTransform, PCollection> transform, + String baseName, + WriteFiles write) + throws IOException { runShardedWrite(inputs, transform, baseName, write); } private static class PerWindowFiles extends FilenamePolicy { private static final DateTimeFormatter FORMATTER = ISODateTimeFormat.hourMinuteSecondMillis(); - private final String prefix; + private final ResourceId baseFilename; private final String suffix; - public PerWindowFiles(String prefix, String suffix) { - this.prefix = prefix; + public PerWindowFiles(ResourceId baseFilename, String suffix) { + this.baseFilename = baseFilename; this.suffix = suffix; } public String filenamePrefixForWindow(IntervalWindow window) { + String prefix = + baseFilename.isDirectory() ? "" : firstNonNull(baseFilename.getFilename(), ""); return String.format("%s%s-%s", prefix, FORMATTER.print(window.start()), FORMATTER.print(window.end())); } @Override - public ResourceId windowedFilename( - ResourceId outputDirectory, WindowedContext context, String extension) { + public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { IntervalWindow window = (IntervalWindow) context.getWindow(); - String filename = String.format( - "%s-%s-of-%s%s%s", - filenamePrefixForWindow(window), context.getShardNumber(), context.getNumShards(), - extension, suffix); - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + String filename = + String.format( + "%s-%s-of-%s%s%s", + filenamePrefixForWindow(window), + context.getShardNumber(), + context.getNumShards(), + outputFileHints.getSuggestedFilenameSuffix(), + suffix); + return baseFilename + .getCurrentDirectory() + .resolve(filename, StandardResolveOptions.RESOLVE_FILE); } @Override - public ResourceId unwindowedFilename( - ResourceId outputDirectory, Context context, String extension) { - String filename = String.format( - "%s%s-of-%s%s%s", - prefix, context.getShardNumber(), context.getNumShards(), - extension, suffix); - return outputDirectory.resolve(filename, StandardResolveOptions.RESOLVE_FILE); + public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + String prefix = + baseFilename.isDirectory() ? "" : firstNonNull(baseFilename.getFilename(), ""); + String filename = + String.format( + "%s-%s-of-%s%s%s", + prefix, + context.getShardNumber(), + context.getNumShards(), + outputFileHints.getSuggestedFilenameSuffix(), + suffix); + return baseFilename + .getCurrentDirectory() + .resolve(filename, StandardResolveOptions.RESOLVE_FILE); } } /** * Performs a WriteFiles transform with the desired number of shards. Verifies the WriteFiles * transform calls the appropriate methods on a test sink in the correct order, as well as - * verifies that the elements of a PCollection are written to the sink. If numConfiguredShards - * is not null, also verifies that the output number of shards is correct. + * verifies that the elements of a PCollection are written to the sink. If numConfiguredShards is + * not null, also verifies that the output number of shards is correct. 
*/ private void runShardedWrite( List inputs, PTransform, PCollection> transform, String baseName, - WriteFiles write) throws IOException { + WriteFiles write) + throws IOException { // Flag to validate that the pipeline options are passed to the Sink WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class); options.setTestFlag("test_value"); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java index 4393a63d86e33..e46b1d3f945be 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.ShardedKeyCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; @@ -57,6 +58,7 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java index edb1e0d982e09..c5c2462be3276 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java @@ -23,8 +23,7 @@ import com.google.api.services.bigquery.model.TableSchema; import com.google.common.collect.Lists; import java.io.Serializable; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.List; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.CannotProvideCoderException; @@ -32,6 +31,7 @@ import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.ValueInSingleWindow; /** @@ -158,21 +158,16 @@ Coder getDestinationCoderWithDefault(CoderRegistry registry) } // If dynamicDestinations doesn't provide a coder, try to find it in the coder registry. // We must first use reflection to figure out what the type parameter is. - for (Type superclass = getClass().getGenericSuperclass(); - superclass != null; - superclass = ((Class) superclass).getGenericSuperclass()) { - if (superclass instanceof ParameterizedType) { - ParameterizedType parameterized = (ParameterizedType) superclass; - if (parameterized.getRawType() == DynamicDestinations.class) { - // DestinationT is the second parameter. 
- Type parameter = parameterized.getActualTypeArguments()[1]; - @SuppressWarnings("unchecked") - Class parameterClass = (Class) parameter; - return registry.getCoder(parameterClass); - } - } + TypeDescriptor superDescriptor = + TypeDescriptor.of(getClass()).getSupertype(DynamicDestinations.class); + if (!superDescriptor.getRawType().equals(DynamicDestinations.class)) { + throw new AssertionError( + "Couldn't find the DynamicDestinations superclass of " + this.getClass()); } - throw new AssertionError( - "Couldn't find the DynamicDestinations superclass of " + this.getClass()); + TypeVariable typeVariable = superDescriptor.getTypeParameter("DestinationT"); + @SuppressWarnings("unchecked") + TypeDescriptor descriptor = + (TypeDescriptor) superDescriptor.resolveType(typeVariable); + return registry.getCoder(descriptor); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/GenerateShardedTable.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/GenerateShardedTable.java index 90d41a0778bb2..55672ff7811e6 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/GenerateShardedTable.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/GenerateShardedTable.java @@ -23,6 +23,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.ShardedKey; /** * Given a write to a specific table, assign that to one of the diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java index 63e5bc1ceb4f8..a210858640322 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteFn.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.SystemDoFnInternal; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.ValueInSingleWindow; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java index 18b203379c7c5..fa5b3ce12a088 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java @@ -19,6 +19,7 @@ import com.google.api.services.bigquery.model.TableRow; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ShardedKeyCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -29,6 +30,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import 
org.apache.beam.sdk.values.TupleTagList; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java index cd88222da69f3..51b9375a587ff 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TagWithUniqueIds.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.ShardedKey; /** * Fn that tags each table row with a unique id and destination table. To avoid calling diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index d68779a273f6e..e1ed746b40382 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk.io.gcp.bigquery; import static com.google.common.base.Preconditions.checkNotNull; + import com.google.api.services.bigquery.model.TableRow; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -40,6 +41,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java index 45dc2a83898b8..887cb9377442e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.ShardedKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java index acd113296c839..451d1bddd606a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WritePartition.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; /** diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index c5494d8834835..9ed2916b36a48 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.ShardedKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index bfd260a308791..d31f3a09e8593 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -82,6 +82,7 @@ import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ShardedKeyCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.BoundedSource; @@ -131,6 +132,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.ShardedKey; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.ValueInSingleWindow; diff --git a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java index 7255a94357eb4..442fba5c0bfcb 100644 --- a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java +++ b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -521,7 +522,8 @@ public void validate(PipelineOptions options) { @Override public PDone expand(PCollection input) { - return input.apply(org.apache.beam.sdk.io.WriteFiles.to(createSink())); + return input.apply( + org.apache.beam.sdk.io.WriteFiles.to(createSink(), SerializableFunctions.identity())); } @VisibleForTesting diff --git a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java index 6ae83f2830661..74e0bda25e77e 100644 --- a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java +++ b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java @@ -25,6 +25,7 @@ import javax.xml.bind.Marshaller; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.DefaultFilenamePolicy; +import 
org.apache.beam.sdk.io.DynamicFileDestinations; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.ShardNameTemplate; import org.apache.beam.sdk.io.fs.ResourceId; @@ -34,18 +35,18 @@ import org.apache.beam.sdk.util.MimeTypes; /** Implementation of {@link XmlIO#write}. */ -class XmlSink extends FileBasedSink { +class XmlSink extends FileBasedSink { private static final String XML_EXTENSION = ".xml"; private final XmlIO.Write spec; - private static DefaultFilenamePolicy makeFilenamePolicy(XmlIO.Write spec) { - return DefaultFilenamePolicy.constructUsingStandardParameters( + private static DefaultFilenamePolicy makeFilenamePolicy(XmlIO.Write spec) { + return DefaultFilenamePolicy.fromStandardParameters( spec.getFilenamePrefix(), ShardNameTemplate.INDEX_OF_MAX, XML_EXTENSION, false); } XmlSink(XmlIO.Write spec) { - super(spec.getFilenamePrefix(), makeFilenamePolicy(spec)); + super(spec.getFilenamePrefix(), DynamicFileDestinations.constant(makeFilenamePolicy(spec))); this.spec = spec; } @@ -75,10 +76,8 @@ void populateFileBasedDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); } - /** - * {@link WriteOperation} for XML {@link FileBasedSink}s. - */ - protected static final class XmlWriteOperation extends WriteOperation { + /** {@link WriteOperation} for XML {@link FileBasedSink}s. */ + protected static final class XmlWriteOperation extends WriteOperation { public XmlWriteOperation(XmlSink sink) { super(sink); } @@ -112,10 +111,8 @@ ResourceId getTemporaryDirectory() { } } - /** - * A {@link Writer} that can write objects as XML elements. - */ - protected static final class XmlWriter extends Writer { + /** A {@link Writer} that can write objects as XML elements. */ + protected static final class XmlWriter extends Writer { final Marshaller marshaller; private OutputStream os = null; diff --git a/sdks/java/io/xml/src/test/java/org/apache/beam/sdk/io/xml/XmlSinkTest.java b/sdks/java/io/xml/src/test/java/org/apache/beam/sdk/io/xml/XmlSinkTest.java index aa0c1c3dbfd3d..d1584dc1d8e02 100644 --- a/sdks/java/io/xml/src/test/java/org/apache/beam/sdk/io/xml/XmlSinkTest.java +++ b/sdks/java/io/xml/src/test/java/org/apache/beam/sdk/io/xml/XmlSinkTest.java @@ -197,8 +197,8 @@ public void testDisplayData() { .withRecordClass(Integer.class); DisplayData displayData = DisplayData.from(write); - - assertThat(displayData, hasDisplayItem("filenamePattern", "file-SSSSS-of-NNNNN.xml")); + assertThat( + displayData, hasDisplayItem("filenamePattern", "/path/to/file-SSSSS-of-NNNNN" + ".xml")); assertThat(displayData, hasDisplayItem("rootElement", "bird")); assertThat(displayData, hasDisplayItem("recordClass", Integer.class)); } From 20ce0756c97f5ed47ad9c8cb46da574c273b5b46 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 6 Jul 2017 09:24:22 -0700 Subject: [PATCH 183/200] Rehydrate PCollections --- .../construction/PCollectionTranslation.java | 16 ++++++++++++++ .../PCollectionTranslationTest.java | 22 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java index 968966f459e1d..52526bbebe75e 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java +++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java @@ -20,6 +20,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.values.PCollection; @@ -47,6 +48,21 @@ public static RunnerApi.PCollection toProto(PCollection pCollection, SdkCompo .build(); } + public static PCollection fromProto( + Pipeline pipeline, RunnerApi.PCollection pCollection, RunnerApi.Components components) + throws IOException { + return PCollection.createPrimitiveOutputInternal( + pipeline, + WindowingStrategyTranslation.fromProto( + components.getWindowingStrategiesOrThrow(pCollection.getWindowingStrategyId()), + components), + fromProto(pCollection.getIsBounded())) + .setCoder( + (Coder) + CoderTranslation.fromProto( + components.getCodersOrThrow(pCollection.getCoderId()), components)); + } + public static IsBounded isBounded(RunnerApi.PCollection pCollection) { return fromProto(pCollection.getIsBounded()); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionTranslationTest.java index 3b942206b1123..5c4548705059b 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionTranslationTest.java @@ -113,6 +113,28 @@ public static Iterable> data() { @Test public void testEncodeDecodeCycle() throws Exception { + // Encode + SdkComponents sdkComponents = SdkComponents.create(); + RunnerApi.PCollection protoCollection = + PCollectionTranslation.toProto(testCollection, sdkComponents); + RunnerApi.Components protoComponents = sdkComponents.toComponents(); + + // Decode + Pipeline pipeline = Pipeline.create(); + PCollection decodedCollection = + PCollectionTranslation.fromProto(pipeline, protoCollection, protoComponents); + + // Verify + assertThat(decodedCollection.getCoder(), Matchers.>equalTo(testCollection.getCoder())); + assertThat( + decodedCollection.getWindowingStrategy(), + Matchers.>equalTo( + testCollection.getWindowingStrategy().fixDefaults())); + assertThat(decodedCollection.isBounded(), equalTo(testCollection.isBounded())); + } + + @Test + public void testEncodeDecodeFields() throws Exception { SdkComponents sdkComponents = SdkComponents.create(); RunnerApi.PCollection protoCollection = PCollectionTranslation .toProto(testCollection, sdkComponents); From 165dfa688beaeb2de9b5041c81f6e02b517f74fd Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 8 Jun 2017 13:46:18 -0700 Subject: [PATCH 184/200] Add more utilities to ParDoTranslation --- .../core/construction/ParDoTranslation.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 34e0d86f1bffb..5f2bcae9e5c6c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -34,9 +34,11 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.io.Serializable; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; @@ -74,6 +76,7 @@ import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; /** @@ -215,11 +218,56 @@ private static TimerSpec getTimerSpecOrCrash( return doFnAndMainOutputTagFromProto(payload.getDoFn()).getDoFn(); } + public static DoFn getDoFn(AppliedPTransform application) throws IOException { + return getDoFn(getParDoPayload(application)); + } + public static TupleTag getMainOutputTag(ParDoPayload payload) throws InvalidProtocolBufferException { return doFnAndMainOutputTagFromProto(payload.getDoFn()).getMainOutputTag(); } + public static TupleTag getMainOutputTag(AppliedPTransform application) + throws IOException { + return getMainOutputTag(getParDoPayload(application)); + } + + public static TupleTagList getAdditionalOutputTags(AppliedPTransform application) + throws IOException { + + RunnerApi.PTransform protoTransform = + PTransformTranslation.toProto(application, SdkComponents.create()); + + ParDoPayload payload = protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); + TupleTag mainOutputTag = getMainOutputTag(payload); + Set outputTags = + Sets.difference( + protoTransform.getOutputsMap().keySet(), Collections.singleton(mainOutputTag.getId())); + + ArrayList> additionalOutputTags = new ArrayList<>(); + for (String outputTag : outputTags) { + additionalOutputTags.add(new TupleTag<>(outputTag)); + } + return TupleTagList.of(additionalOutputTags); + } + + public static List> getSideInputs(AppliedPTransform application) + throws IOException { + + SdkComponents sdkComponents = SdkComponents.create(); + RunnerApi.PTransform parDoProto = + PTransformTranslation.toProto(application, sdkComponents); + ParDoPayload payload = parDoProto.getSpec().getParameter().unpack(ParDoPayload.class); + + List> views = new ArrayList<>(); + for (Map.Entry sideInput : payload.getSideInputsMap().entrySet()) { + views.add( + fromProto( + sideInput.getValue(), sideInput.getKey(), parDoProto, sdkComponents.toComponents())); + } + return views; + } + public static RunnerApi.PCollection getMainInput( RunnerApi.PTransform ptransform, Components components) throws IOException { checkArgument( From de39f324c4b0914418894a41c6f75596310bf633 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 6 Jul 2017 09:24:55 -0700 Subject: [PATCH 185/200] Include PCollection in rehydrated PCollectionView --- .../core/construction/ParDoTranslation.java | 51 ++++++++++++++++--- .../construction/RunnerPCollectionView.java | 7 +-- .../construction/ParDoTranslationTest.java | 28 ++++++---- 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 5f2bcae9e5c6c..fe8c5aad247b3 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -40,6 +40,7 @@ import java.util.Map; import java.util.Set; import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; @@ -74,6 +75,7 @@ import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -262,8 +264,12 @@ public static List> getSideInputs(AppliedPTransform List> views = new ArrayList<>(); for (Map.Entry sideInput : payload.getSideInputsMap().entrySet()) { views.add( - fromProto( - sideInput.getValue(), sideInput.getKey(), parDoProto, sdkComponents.toComponents())); + viewFromProto( + application.getPipeline(), + sideInput.getValue(), + sideInput.getKey(), + parDoProto, + sdkComponents.toComponents())); } return views; } @@ -495,15 +501,47 @@ private static SideInput toProto(PCollectionView view) { return builder.build(); } - public static PCollectionView fromProto( - SideInput sideInput, String id, RunnerApi.PTransform parDoTransform, Components components) + public static PCollectionView viewFromProto( + Pipeline pipeline, + SideInput sideInput, + String localName, + RunnerApi.PTransform parDoTransform, + Components components) throws IOException { - TupleTag tag = new TupleTag<>(id); + + String pCollectionId = parDoTransform.getInputsOrThrow(localName); + + // This may be a PCollection defined in another language, but we should be + // able to rehydrate it enough to stick it in a side input. The coder may not + // be grokkable in Java. + PCollection pCollection = + PCollectionTranslation.fromProto( + pipeline, components.getPcollectionsOrThrow(pCollectionId), components); + + return viewFromProto(sideInput, localName, pCollection, parDoTransform, components); + } + + /** + * Create a {@link PCollectionView} from a side input spec and an already-deserialized {@link + * PCollection} that should be wired up. 
+ */ + public static PCollectionView viewFromProto( + SideInput sideInput, + String localName, + PCollection pCollection, + RunnerApi.PTransform parDoTransform, + Components components) + throws IOException { + checkArgument( + localName != null, + "%s.viewFromProto: localName must not be null", + ParDoTranslation.class.getSimpleName()); + TupleTag tag = new TupleTag<>(localName); WindowMappingFn windowMappingFn = windowMappingFnFromProto(sideInput.getWindowMappingFn()); ViewFn viewFn = viewFnFromProto(sideInput.getViewFn()); RunnerApi.PCollection inputCollection = - components.getPcollectionsOrThrow(parDoTransform.getInputsOrThrow(id)); + components.getPcollectionsOrThrow(parDoTransform.getInputsOrThrow(localName)); WindowingStrategy windowingStrategy = WindowingStrategyTranslation.fromProto( components.getWindowingStrategiesOrThrow(inputCollection.getWindowingStrategyId()), @@ -523,6 +561,7 @@ public static PCollectionView fromProto( PCollectionView view = new RunnerPCollectionView<>( + pCollection, (TupleTag>>) tag, (ViewFn>, ?>) viewFn, windowMappingFn, diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java index c359cecce361d..b27518824b8c4 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java @@ -39,16 +39,19 @@ class RunnerPCollectionView extends PValueBase implements PCollectionView private final WindowMappingFn windowMappingFn; private final WindowingStrategy windowingStrategy; private final Coder>> coder; + private final transient PCollection pCollection; /** * Create a new {@link RunnerPCollectionView} from the provided components. 
*/ RunnerPCollectionView( + PCollection pCollection, TupleTag>> tag, ViewFn>, T> viewFn, WindowMappingFn windowMappingFn, @Nullable WindowingStrategy windowingStrategy, @Nullable Coder>> coder) { + this.pCollection = pCollection; this.tag = tag; this.viewFn = viewFn; this.windowMappingFn = windowMappingFn; @@ -56,11 +59,9 @@ class RunnerPCollectionView extends PValueBase implements PCollectionView this.coder = coder; } - @Nullable @Override public PCollection getPCollection() { - throw new IllegalStateException( - String.format("Cannot call getPCollection on a %s", getClass().getSimpleName())); + return pCollection; } @Override diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index a8490bf276d61..6fdf9d6ad8b73 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -23,9 +23,9 @@ import static org.junit.Assert.assertThat; import com.google.common.collect.ImmutableList; -import java.util.Collections; import java.util.HashMap; import java.util.Map; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; @@ -143,22 +143,30 @@ public void toAndFromTransformProto() throws Exception { inputs.putAll(parDo.getAdditionalInputs()); PCollectionTuple output = mainInput.apply(parDo); - SdkComponents components = SdkComponents.create(); - String transformId = - components.registerPTransform( + SdkComponents sdkComponents = SdkComponents.create(); + + // Encode + RunnerApi.PTransform protoTransform = + PTransformTranslation.toProto( AppliedPTransform.>, PCollection, MultiOutput>of( "foo", inputs, output.expand(), parDo, p), - Collections.>emptyList()); + sdkComponents); + Components protoComponents = sdkComponents.toComponents(); + + // Decode + Pipeline rehydratedPipeline = Pipeline.create(); - Components protoComponents = components.toComponents(); - RunnerApi.PTransform protoTransform = protoComponents.getTransformsOrThrow(transformId); ParDoPayload parDoPayload = protoTransform.getSpec().getParameter().unpack(ParDoPayload.class); for (PCollectionView view : parDo.getSideInputs()) { SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); PCollectionView restoredView = - ParDoTranslation.fromProto( - sideInput, view.getTagInternal().getId(), protoTransform, protoComponents); + ParDoTranslation.viewFromProto( + rehydratedPipeline, + sideInput, + view.getTagInternal().getId(), + protoTransform, + protoComponents); assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); assertThat(restoredView.getViewFn(), instanceOf(view.getViewFn().getClass())); assertThat( @@ -169,7 +177,7 @@ public void toAndFromTransformProto() throws Exception { view.getWindowingStrategyInternal().fixDefaults())); assertThat(restoredView.getCoderInternal(), equalTo(view.getCoderInternal())); } - String mainInputId = components.registerPCollection(mainInput); + String mainInputId = sdkComponents.registerPCollection(mainInput); assertThat( ParDoTranslation.getMainInput(protoTransform, protoComponents), equalTo(protoComponents.getPcollectionsOrThrow(mainInputId))); From 
fa61ed17424083fa53b8aa8e70908fb6194ad4ad Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 8 Jun 2017 14:27:02 -0700 Subject: [PATCH 186/200] Enable SplittableParDo on rehydrated ParDo transform --- .../core/construction/SplittableParDo.java | 25 +++++++++++++ .../direct/ParDoMultiOverrideFactory.java | 36 +++++++++++++------ .../FlinkStreamingPipelineTranslator.java | 2 +- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index f31b495739b8a..e71187be6c256 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.UUID; @@ -26,6 +27,7 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -103,6 +105,9 @@ private SplittableParDo( public static SplittableParDo forJavaParDo( ParDo.MultiOutput parDo) { checkArgument(parDo != null, "parDo must not be null"); + checkArgument( + DoFnSignatures.getSignature(parDo.getFn().getClass()).processElement().isSplittable(), + "fn must be a splittable DoFn"); return new SplittableParDo( parDo.getFn(), parDo.getMainOutputTag(), @@ -110,6 +115,26 @@ private SplittableParDo( parDo.getAdditionalOutputTags()); } + /** + * Creates the transform for a {@link ParDo}-compatible {@link AppliedPTransform}. + * + *

    The input may generally be a deserialized transform so it may not actually be a {@link + * ParDo}. Instead {@link ParDoTranslation} will be used to extract fields. + */ + public static SplittableParDo forAppliedParDo(AppliedPTransform parDo) { + checkArgument(parDo != null, "parDo must not be null"); + + try { + return new SplittableParDo<>( + ParDoTranslation.getDoFn(parDo), + (TupleTag) ParDoTranslation.getMainOutputTag(parDo), + ParDoTranslation.getSideInputs(parDo), + ParDoTranslation.getAdditionalOutputTags(parDo)); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + @Override public PCollectionTuple expand(PCollection input) { Coder restrictionCoder = diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 2904bc170c442..888196765ad0d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkState; +import java.io.IOException; import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; @@ -26,6 +27,7 @@ import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.sdk.coders.Coder; @@ -62,29 +64,43 @@ */ class ParDoMultiOverrideFactory implements PTransformOverrideFactory< - PCollection, PCollectionTuple, MultiOutput> { + PCollection, PCollectionTuple, + PTransform, PCollectionTuple>> { @Override public PTransformReplacement, PCollectionTuple> getReplacementTransform( AppliedPTransform< - PCollection, PCollectionTuple, MultiOutput> - transform) { + PCollection, PCollectionTuple, + PTransform, PCollectionTuple>> + application) { return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(transform), - getReplacementTransform(transform.getTransform())); + PTransformReplacements.getSingletonMainInput(application), + getReplacementForApplication(application)); } @SuppressWarnings("unchecked") - private PTransform, PCollectionTuple> getReplacementTransform( - MultiOutput transform) { + private PTransform, PCollectionTuple> getReplacementForApplication( + AppliedPTransform< + PCollection, PCollectionTuple, + PTransform, PCollectionTuple>> + application) { + + DoFn fn; + try { + fn = (DoFn) ParDoTranslation.getDoFn(application); + } catch (IOException exc) { + throw new RuntimeException(exc); + } - DoFn fn = transform.getFn(); DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); if (signature.processElement().isSplittable()) { - return (PTransform) SplittableParDo.forJavaParDo(transform); + return (PTransform) SplittableParDo.forAppliedParDo(application); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { + MultiOutput transform = + (MultiOutput) application.getTransform(); + // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed return new GbkThenStatefulParDo( @@ 
-93,7 +109,7 @@ private PTransform, PCollectionTuple> getReplaceme transform.getAdditionalOutputTags(), transform.getSideInputs()); } else { - return transform; + return application.getTransform(); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java index ebc934516181f..f733e2e7513a3 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java @@ -188,7 +188,7 @@ static class SplittableParDoOverrideFactory transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - SplittableParDo.forJavaParDo(transform.getTransform())); + (SplittableParDo) SplittableParDo.forAppliedParDo(transform)); } @Override From 1ac4b7e6f7dbbd68c27c6634cd52767885a42760 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 8 Jun 2017 13:44:52 -0700 Subject: [PATCH 187/200] Port DirectRunner ParDo overrides to SDK-agnostic APIs --- .../core/construction/ParDoTranslation.java | 16 ++++++--- .../construction/RunnerPCollectionView.java | 16 +++++++++ .../direct/ParDoMultiOverrideFactory.java | 35 ++++++++----------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index fe8c5aad247b3..90c9aadfdfb86 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.core.construction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; @@ -262,12 +263,19 @@ public static List> getSideInputs(AppliedPTransform ParDoPayload payload = parDoProto.getSpec().getParameter().unpack(ParDoPayload.class); List> views = new ArrayList<>(); - for (Map.Entry sideInput : payload.getSideInputsMap().entrySet()) { + for (Map.Entry sideInputEntry : payload.getSideInputsMap().entrySet()) { + String sideInputTag = sideInputEntry.getKey(); + RunnerApi.SideInput sideInput = sideInputEntry.getValue(); + PCollection originalPCollection = + checkNotNull( + (PCollection) application.getInputs().get(new TupleTag<>(sideInputTag)), + "no input with tag %s", + sideInputTag); views.add( viewFromProto( - application.getPipeline(), - sideInput.getValue(), - sideInput.getKey(), + sideInput, + sideInputTag, + originalPCollection, parDoProto, sdkComponents.toComponents())); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java index b27518824b8c4..85139e8851ad4 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java +++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RunnerPCollectionView.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.core.construction; import java.util.Map; +import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput; @@ -94,4 +95,19 @@ public Map, PValue> expand() { throw new UnsupportedOperationException(String.format( "A %s cannot be expanded", RunnerPCollectionView.class.getSimpleName())); } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PCollectionView)) { + return false; + } + @SuppressWarnings("unchecked") + PCollectionView otherView = (PCollectionView) other; + return tag.equals(otherView.getTagInternal()); + } + + @Override + public int hashCode() { + return Objects.hash(tag); + } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 888196765ad0d..891d1020787b4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -38,7 +38,6 @@ import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.ParDo.MultiOutput; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.AfterPane; @@ -73,9 +72,14 @@ class ParDoMultiOverrideFactory PCollection, PCollectionTuple, PTransform, PCollectionTuple>> application) { - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(application), - getReplacementForApplication(application)); + + try { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(application), + getReplacementForApplication(application)); + } catch (IOException exc) { + throw new RuntimeException(exc); + } } @SuppressWarnings("unchecked") @@ -83,31 +87,22 @@ private PTransform, PCollectionTuple> getReplaceme AppliedPTransform< PCollection, PCollectionTuple, PTransform, PCollectionTuple>> - application) { + application) + throws IOException { - DoFn fn; - try { - fn = (DoFn) ParDoTranslation.getDoFn(application); - } catch (IOException exc) { - throw new RuntimeException(exc); - } + DoFn fn = (DoFn) ParDoTranslation.getDoFn(application); DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + if (signature.processElement().isSplittable()) { return (PTransform) SplittableParDo.forAppliedParDo(application); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { - - MultiOutput transform = - (MultiOutput) application.getTransform(); - - // Based on the fact that the signature is stateful, DoFnSignatures ensures - // that it is also keyed return new GbkThenStatefulParDo( fn, - transform.getMainOutputTag(), - transform.getAdditionalOutputTags(), - transform.getSideInputs()); + ParDoTranslation.getMainOutputTag(application), + ParDoTranslation.getAdditionalOutputTags(application), + ParDoTranslation.getSideInputs(application)); } else { return application.getTransform(); } From be9a387976adf3424d680778f92ce22f728ffa32 Mon Sep 17 00:00:00 2001 
From: Kenneth Knowles Date: Mon, 12 Jun 2017 15:11:49 -0700 Subject: [PATCH 188/200] Fix misleading comment in TransformHierarchy --- .../java/org/apache/beam/sdk/runners/TransformHierarchy.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 9c5f14843c240..6f1ee94b9ba95 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -406,7 +406,7 @@ public String getFullName() { return fullName; } - /** Returns the transform input, in unexpanded form. */ + /** Returns the transform input, in fully expanded form. */ public Map, PValue> getInputs() { return inputs == null ? Collections., PValue>emptyMap() : inputs; } From 1518d732e74c61d021509d2fc325427cb93e73e8 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 12 Jun 2017 15:12:18 -0700 Subject: [PATCH 189/200] Fix null checks in TransformHierarchy --- .../apache/beam/sdk/runners/TransformHierarchy.java | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 6f1ee94b9ba95..d8ff59e7b7b16 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -145,14 +145,6 @@ public void finishSpecifyingInput() { Node producerNode = getProducer(inputValue); PInput input = producerInput.remove(inputValue); inputValue.finishSpecifying(input, producerNode.getTransform()); - checkState( - producers.get(inputValue) != null, - "Producer unknown for input %s", - inputValue); - checkState( - producers.get(inputValue) != null, - "Producer unknown for input %s", - inputValue); } } @@ -201,7 +193,7 @@ public void popNode() { } Node getProducer(PValue produced) { - return producers.get(produced); + return checkNotNull(producers.get(produced), "No producer found for %s", produced); } public Set visit(PipelineVisitor visitor) { From 3099e81428ea19cdb7e3ef6b35a5de462c598ef8 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 28 Jun 2017 18:20:12 -0700 Subject: [PATCH 190/200] Split bundle processor into separate class. 
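This patch moves the Fn API bundle-execution logic out of sdk_worker.py into a
new apache_beam.runners.worker.bundle_processor module, leaving sdk_worker.py to
handle only the control-plane plumbing. A minimal usage sketch follows, assuming
only the BundleProcessor constructor and process_bundle() signatures introduced
in the diff below; the run_one_bundle() wrapper is illustrative, not part of
this patch, and the exact wiring inside SdkWorker may differ:

    from apache_beam.runners.worker import bundle_processor
    from apache_beam.runners.worker import data_plane

    def run_one_bundle(process_bundle_descriptor, state_handler, instruction_id):
      # Hypothetical helper (not in this patch): executes a single bundle for
      # one ProcessBundleDescriptor using the new class.
      processor = bundle_processor.BundleProcessor(
          process_bundle_descriptor,
          state_handler,
          data_plane.GrpcClientDataChannelFactory())
      # Builds the operation tree from the descriptor, starts the operations,
      # feeds data-plane input for this instruction id, then finishes them.
      processor.process_bundle(instruction_id)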
--- .../runners/portability/fn_api_runner.py | 20 +- .../runners/worker/bundle_processor.py | 426 ++++++++++++++++++ .../apache_beam/runners/worker/sdk_worker.py | 398 +--------------- 3 files changed, 444 insertions(+), 400 deletions(-) create mode 100644 sdks/python/apache_beam/runners/worker/bundle_processor.py diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index f52286456530c..f88fe53309beb 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -38,6 +38,7 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.portability import maptask_executor_runner +from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import sdk_worker @@ -186,7 +187,7 @@ def only_element(iterable): target_name = only_element(get_inputs(operation).keys()) runner_sinks[(transform_id, target_name)] = operation transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=sdk_worker.DATA_OUTPUT_URN, + urn=bundle_processor.DATA_OUTPUT_URN, parameter=proto_utils.pack_Any(data_operation_spec)) elif isinstance(operation, operation_specs.WorkerRead): @@ -200,7 +201,7 @@ def only_element(iterable): operation.source.source.read(None), operation.source.source.default_output_coder()) transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=sdk_worker.DATA_INPUT_URN, + urn=bundle_processor.DATA_INPUT_URN, parameter=proto_utils.pack_Any(data_operation_spec)) else: @@ -209,7 +210,7 @@ def only_element(iterable): # The Dataflow runner harness strips the base64 encoding. do the same # here until we get the same thing back that we sent in. transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=sdk_worker.PYTHON_SOURCE_URN, + urn=bundle_processor.PYTHON_SOURCE_URN, parameter=proto_utils.pack_Any( wrappers_pb2.BytesValue( value=base64.b64decode( @@ -223,21 +224,22 @@ def only_element(iterable): element_coder = si.source.default_output_coder() # TODO(robertwb): Actually flesh out the ViewFn API. side_input_extras.append((si.tag, element_coder)) - side_input_data[sdk_worker.side_input_tag(transform_id, si.tag)] = ( - self._reencode_elements( - si.source.read(si.source.get_range_tracker(None, None)), - element_coder)) + side_input_data[ + bundle_processor.side_input_tag(transform_id, si.tag)] = ( + self._reencode_elements( + si.source.read(si.source.get_range_tracker(None, None)), + element_coder)) augmented_serialized_fn = pickler.dumps( (operation.serialized_fn, side_input_extras)) transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=sdk_worker.PYTHON_DOFN_URN, + urn=bundle_processor.PYTHON_DOFN_URN, parameter=proto_utils.pack_Any( wrappers_pb2.BytesValue(value=augmented_serialized_fn))) elif isinstance(operation, operation_specs.WorkerFlatten): # Flatten is nice and simple. 
transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=sdk_worker.IDENTITY_DOFN_URN) + urn=bundle_processor.IDENTITY_DOFN_URN) else: raise NotImplementedError(operation) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py new file mode 100644 index 0000000000000..2669bfce947a2 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -0,0 +1,426 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""SDK harness for executing Python Fns via the Fn API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import base64 +import collections +import json +import logging + +from google.protobuf import wrappers_pb2 + +from apache_beam.coders import coder_impl +from apache_beam.coders import WindowedValueCoder +from apache_beam.internal import pickler +from apache_beam.io import iobase +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.dataflow.native_io import iobase as native_iobase +from apache_beam.runners import pipeline_context +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import operations +from apache_beam.utils import counters +from apache_beam.utils import proto_utils + +# This module is experimental. No backwards-compatibility guarantees. + + +try: + from apache_beam.runners.worker import statesampler +except ImportError: + from apache_beam.runners.worker import statesampler_fake as statesampler + + +DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' +DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' +IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' +PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' +PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' +# TODO(vikasrk): Fix this once runner sends appropriate python urns. +PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1' +PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1' + + +def side_input_tag(transform_id, tag): + return str("%d[%s][%s]" % (len(transform_id), transform_id, tag)) + + +class RunnerIOOperation(operations.Operation): + """Common baseclass for runner harness IO operations.""" + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, target, data_channel): + super(RunnerIOOperation, self).__init__( + operation_name, None, counter_factory, state_sampler) + self.windowed_coder = windowed_coder + self.step_name = step_name + # target represents the consumer for the bytes in the data plane for a + # DataInputOperation or a producer of these bytes for a DataOutputOperation. 
+ self.target = target + self.data_channel = data_channel + for _, consumer_ops in consumers.items(): + for consumer in consumer_ops: + self.add_receiver(consumer, 0) + + +class DataOutputOperation(RunnerIOOperation): + """A sink-like operation that gathers outputs to be sent back to the runner. + """ + + def set_output_stream(self, output_stream): + self.output_stream = output_stream + + def process(self, windowed_value): + self.windowed_coder.get_impl().encode_to_stream( + windowed_value, self.output_stream, True) + + def finish(self): + self.output_stream.close() + super(DataOutputOperation, self).finish() + + +class DataInputOperation(RunnerIOOperation): + """A source-like operation that gathers input from the runner. + """ + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, input_target, data_channel): + super(DataInputOperation, self).__init__( + operation_name, step_name, consumers, counter_factory, state_sampler, + windowed_coder, target=input_target, data_channel=data_channel) + # We must do this manually as we don't have a spec or spec.output_coders. + self.receivers = [ + operations.ConsumerSet(self.counter_factory, self.step_name, 0, + consumers.itervalues().next(), + self.windowed_coder)] + + def process(self, windowed_value): + self.output(windowed_value) + + def process_encoded(self, encoded_windowed_values): + input_stream = coder_impl.create_InputStream(encoded_windowed_values) + while input_stream.size() > 0: + decoded_value = self.windowed_coder.get_impl().decode_from_stream( + input_stream, True) + self.output(decoded_value) + + +# TODO(robertwb): Revise side input API to not be in terms of native sources. +# This will enable lookups, but there's an open question as to how to handle +# custom sources without forcing intermediate materialization. This seems very +# related to the desire to inject key and window preserving [Splittable]DoFns +# into the view computation. +class SideInputSource(native_iobase.NativeSource, + native_iobase.NativeSourceReader): + """A 'source' for reading side inputs via state API calls. + """ + + def __init__(self, state_handler, state_key, coder): + self._state_handler = state_handler + self._state_key = state_key + self._coder = coder + + def reader(self): + return self + + @property + def returns_windowed_values(self): + return True + + def __enter__(self): + return self + + def __exit__(self, *exn_info): + pass + + def __iter__(self): + # TODO(robertwb): Support pagination. + input_stream = coder_impl.create_InputStream( + self._state_handler.Get(self._state_key).data) + while input_stream.size() > 0: + yield self._coder.get_impl().decode_from_stream(input_stream, True) + + +def memoize(func): + cache = {} + missing = object() + + def wrapper(*args): + result = cache.get(args, missing) + if result is missing: + result = cache[args] = func(*args) + return result + return wrapper + + +def only_element(iterable): + element, = iterable + return element + + +class BundleProcessor(object): + """A class for processing bundles of elements. + """ + def __init__( + self, process_bundle_descriptor, state_handler, data_channel_factory): + self.process_bundle_descriptor = process_bundle_descriptor + self.state_handler = state_handler + self.data_channel_factory = data_channel_factory + + def create_execution_tree(self, descriptor): + # TODO(robertwb): Figure out the correct prefix to use for output counters + # from StateSampler. 
+ counter_factory = counters.CounterFactory() + state_sampler = statesampler.StateSampler( + 'fnapi-step%s-' % descriptor.id, counter_factory) + + transform_factory = BeamTransformFactory( + descriptor, self.data_channel_factory, counter_factory, state_sampler, + self.state_handler) + + pcoll_consumers = collections.defaultdict(list) + for transform_id, transform_proto in descriptor.transforms.items(): + for pcoll_id in transform_proto.inputs.values(): + pcoll_consumers[pcoll_id].append(transform_id) + + @memoize + def get_operation(transform_id): + transform_consumers = { + tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] + for tag, pcoll_id + in descriptor.transforms[transform_id].outputs.items() + } + return transform_factory.create_operation( + transform_id, transform_consumers) + + # Operations must be started (hence returned) in order. + @memoize + def topological_height(transform_id): + return 1 + max( + [0] + + [topological_height(consumer) + for pcoll in descriptor.transforms[transform_id].outputs.values() + for consumer in pcoll_consumers[pcoll]]) + + return [get_operation(transform_id) + for transform_id in sorted( + descriptor.transforms, key=topological_height, reverse=True)] + + def process_bundle(self, instruction_id): + ops = self.create_execution_tree(self.process_bundle_descriptor) + + expected_inputs = [] + for op in ops: + if isinstance(op, DataOutputOperation): + # TODO(robertwb): Is there a better way to pass the instruction id to + # the operation? + op.set_output_stream(op.data_channel.output_stream( + instruction_id, op.target)) + elif isinstance(op, DataInputOperation): + # We must wait until we receive "end of stream" for each of these ops. + expected_inputs.append(op) + + # Start all operations. + for op in reversed(ops): + logging.info('start %s', op) + op.start() + + # Inject inputs from data plane. + for input_op in expected_inputs: + for data in input_op.data_channel.input_elements( + instruction_id, [input_op.target]): + # ignores input name + input_op.process_encoded(data.data) + + # Finish all operations. + for op in ops: + logging.info('finish %s', op) + op.finish() + + +class BeamTransformFactory(object): + """Factory for turning transform_protos into executable operations.""" + def __init__(self, descriptor, data_channel_factory, counter_factory, + state_sampler, state_handler): + self.descriptor = descriptor + self.data_channel_factory = data_channel_factory + self.counter_factory = counter_factory + self.state_sampler = state_sampler + self.state_handler = state_handler + self.context = pipeline_context.PipelineContext(descriptor) + + _known_urns = {} + + @classmethod + def register_urn(cls, urn, parameter_type): + def wrapper(func): + cls._known_urns[urn] = func, parameter_type + return func + return wrapper + + def create_operation(self, transform_id, consumers): + transform_proto = self.descriptor.transforms[transform_id] + creator, parameter_type = self._known_urns[transform_proto.spec.urn] + parameter = proto_utils.unpack_Any( + transform_proto.spec.parameter, parameter_type) + return creator(self, transform_id, transform_proto, parameter, consumers) + + def get_coder(self, coder_id): + coder_proto = self.descriptor.coders[coder_id] + if coder_proto.spec.spec.urn: + return self.context.coders.get_by_id(coder_id) + else: + # No URN, assume cloud object encoding json bytes. 
+ return operation_specs.get_coder_from_spec( + json.loads( + proto_utils.unpack_Any(coder_proto.spec.spec.parameter, + wrappers_pb2.BytesValue).value)) + + def get_output_coders(self, transform_proto): + return { + tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) + for tag, pcoll_id in transform_proto.outputs.items() + } + + def get_only_output_coder(self, transform_proto): + return only_element(self.get_output_coders(transform_proto).values()) + + def get_input_coders(self, transform_proto): + return { + tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) + for tag, pcoll_id in transform_proto.inputs.items() + } + + def get_only_input_coder(self, transform_proto): + return only_element(self.get_input_coders(transform_proto).values()) + + # TODO(robertwb): Update all operations to take these in the constructor. + @staticmethod + def augment_oldstyle_op(op, step_name, consumers, tag_list=None): + op.step_name = step_name + for tag, op_consumers in consumers.items(): + for consumer in op_consumers: + op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0) + return op + + +@BeamTransformFactory.register_urn( + DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) +def create(factory, transform_id, transform_proto, grpc_port, consumers): + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform_id, + name=only_element(transform_proto.outputs.keys())) + return DataInputOperation( + transform_proto.unique_name, + transform_proto.unique_name, + consumers, + factory.counter_factory, + factory.state_sampler, + factory.get_only_output_coder(transform_proto), + input_target=target, + data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) + + +@BeamTransformFactory.register_urn( + DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) +def create(factory, transform_id, transform_proto, grpc_port, consumers): + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform_id, + name=only_element(transform_proto.inputs.keys())) + return DataOutputOperation( + transform_proto.unique_name, + transform_proto.unique_name, + consumers, + factory.counter_factory, + factory.state_sampler, + # TODO(robertwb): Perhaps this could be distinct from the input coder? + factory.get_only_input_coder(transform_proto), + target=target, + data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) + + +@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue) +def create(factory, transform_id, transform_proto, parameter, consumers): + # The Dataflow runner harness strips the base64 encoding. + source = pickler.loads(base64.b64encode(parameter.value)) + spec = operation_specs.WorkerRead( + iobase.SourceBundle(1.0, source, None, None), + [WindowedValueCoder(source.default_output_coder())]) + return factory.augment_oldstyle_op( + operations.ReadOperation( + transform_proto.unique_name, + spec, + factory.counter_factory, + factory.state_sampler), + transform_proto.unique_name, + consumers) + + +@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue) +def create(factory, transform_id, transform_proto, parameter, consumers): + dofn_data = pickler.loads(parameter.value) + if len(dofn_data) == 2: + # Has side input data. + serialized_fn, side_input_data = dofn_data + else: + # No side input data. + serialized_fn, side_input_data = parameter.value, [] + + def create_side_input(tag, coder): + # TODO(robertwb): Extract windows (and keys) out of element data. 
+ # TODO(robertwb): Extract state key from ParDoPayload. + return operation_specs.WorkerSideInputSource( + tag=tag, + source=SideInputSource( + factory.state_handler, + beam_fn_api_pb2.StateKey.MultimapSideInput( + key=side_input_tag(transform_id, tag)), + coder=coder)) + output_tags = list(transform_proto.outputs.keys()) + output_coders = factory.get_output_coders(transform_proto) + spec = operation_specs.WorkerDoFn( + serialized_fn=serialized_fn, + output_tags=output_tags, + input=None, + side_inputs=[ + create_side_input(tag, coder) for tag, coder in side_input_data], + output_coders=[output_coders[tag] for tag in output_tags]) + return factory.augment_oldstyle_op( + operations.DoOperation( + transform_proto.unique_name, + spec, + factory.counter_factory, + factory.state_sampler), + transform_proto.unique_name, + consumers, + output_tags) + + +@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) +def create(factory, transform_id, transform_proto, unused_parameter, consumers): + return factory.augment_oldstyle_op( + operations.FlattenOperation( + transform_proto.unique_name, + None, + factory.counter_factory, + factory.state_sampler), + transform_proto.unique_name, + consumers) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index ae8683047122f..6a236802b9a6f 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -21,170 +21,21 @@ from __future__ import division from __future__ import print_function -import base64 -import collections -import json import logging import Queue as queue import threading import traceback -from google.protobuf import wrappers_pb2 - -from apache_beam.coders import coder_impl -from apache_beam.coders import WindowedValueCoder -from apache_beam.internal import pickler -from apache_beam.io import iobase from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.runners.dataflow.native_io import iobase as native_iobase -from apache_beam.runners import pipeline_context -from apache_beam.runners.worker import operation_specs -from apache_beam.runners.worker import operations -from apache_beam.utils import counters -from apache_beam.utils import proto_utils - -# This module is experimental. No backwards-compatibility guarantees. - - -try: - from apache_beam.runners.worker import statesampler -except ImportError: - from apache_beam.runners.worker import statesampler_fake as statesampler -from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory - - -DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' -DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' -IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' -PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' -PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' -# TODO(vikasrk): Fix this once runner sends appropriate python urns. 
-PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1' -PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1' - - -def side_input_tag(transform_id, tag): - return str("%d[%s][%s]" % (len(transform_id), transform_id, tag)) - - -class RunnerIOOperation(operations.Operation): - """Common baseclass for runner harness IO operations.""" - - def __init__(self, operation_name, step_name, consumers, counter_factory, - state_sampler, windowed_coder, target, data_channel): - super(RunnerIOOperation, self).__init__( - operation_name, None, counter_factory, state_sampler) - self.windowed_coder = windowed_coder - self.step_name = step_name - # target represents the consumer for the bytes in the data plane for a - # DataInputOperation or a producer of these bytes for a DataOutputOperation. - self.target = target - self.data_channel = data_channel - for _, consumer_ops in consumers.items(): - for consumer in consumer_ops: - self.add_receiver(consumer, 0) - - -class DataOutputOperation(RunnerIOOperation): - """A sink-like operation that gathers outputs to be sent back to the runner. - """ - - def set_output_stream(self, output_stream): - self.output_stream = output_stream - - def process(self, windowed_value): - self.windowed_coder.get_impl().encode_to_stream( - windowed_value, self.output_stream, True) - - def finish(self): - self.output_stream.close() - super(DataOutputOperation, self).finish() - - -class DataInputOperation(RunnerIOOperation): - """A source-like operation that gathers input from the runner. - """ - - def __init__(self, operation_name, step_name, consumers, counter_factory, - state_sampler, windowed_coder, input_target, data_channel): - super(DataInputOperation, self).__init__( - operation_name, step_name, consumers, counter_factory, state_sampler, - windowed_coder, target=input_target, data_channel=data_channel) - # We must do this manually as we don't have a spec or spec.output_coders. - self.receivers = [ - operations.ConsumerSet(self.counter_factory, self.step_name, 0, - consumers.itervalues().next(), - self.windowed_coder)] - - def process(self, windowed_value): - self.output(windowed_value) - - def process_encoded(self, encoded_windowed_values): - input_stream = coder_impl.create_InputStream(encoded_windowed_values) - while input_stream.size() > 0: - decoded_value = self.windowed_coder.get_impl().decode_from_stream( - input_stream, True) - self.output(decoded_value) - - -# TODO(robertwb): Revise side input API to not be in terms of native sources. -# This will enable lookups, but there's an open question as to how to handle -# custom sources without forcing intermediate materialization. This seems very -# related to the desire to inject key and window preserving [Splittable]DoFns -# into the view computation. -class SideInputSource(native_iobase.NativeSource, - native_iobase.NativeSourceReader): - """A 'source' for reading side inputs via state API calls. - """ - - def __init__(self, state_handler, state_key, coder): - self._state_handler = state_handler - self._state_key = state_key - self._coder = coder - - def reader(self): - return self - - @property - def returns_windowed_values(self): - return True - - def __enter__(self): - return self - - def __exit__(self, *exn_info): - pass - - def __iter__(self): - # TODO(robertwb): Support pagination. 
- input_stream = coder_impl.create_InputStream( - self._state_handler.Get(self._state_key).data) - while input_stream.size() > 0: - yield self._coder.get_impl().decode_from_stream(input_stream, True) - - -def memoize(func): - cache = {} - missing = object() - - def wrapper(*args): - result = cache.get(args, missing) - if result is missing: - result = cache[args] = func(*args) - return result - return wrapper - - -def only_element(iterable): - element, = iterable - return element +from apache_beam.runners.worker import bundle_processor +from apache_beam.runners.worker import data_plane class SdkHarness(object): def __init__(self, control_channel): self._control_channel = control_channel - self._data_channel_factory = GrpcClientDataChannelFactory() + self._data_channel_factory = data_plane.GrpcClientDataChannelFactory() def run(self): contol_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel) @@ -251,245 +102,10 @@ def register(self, request, unused_instruction_id=None): self.fns[process_bundle_descriptor.id] = process_bundle_descriptor return beam_fn_api_pb2.RegisterResponse() - def create_execution_tree(self, descriptor): - # TODO(robertwb): Figure out the correct prefix to use for output counters - # from StateSampler. - counter_factory = counters.CounterFactory() - state_sampler = statesampler.StateSampler( - 'fnapi-step%s-' % descriptor.id, counter_factory) - - transform_factory = BeamTransformFactory( - descriptor, self.data_channel_factory, counter_factory, state_sampler, - self.state_handler) - - pcoll_consumers = collections.defaultdict(list) - for transform_id, transform_proto in descriptor.transforms.items(): - for pcoll_id in transform_proto.inputs.values(): - pcoll_consumers[pcoll_id].append(transform_id) - - @memoize - def get_operation(transform_id): - transform_consumers = { - tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] - for tag, pcoll_id - in descriptor.transforms[transform_id].outputs.items() - } - return transform_factory.create_operation( - transform_id, transform_consumers) - - # Operations must be started (hence returned) in order. - @memoize - def topological_height(transform_id): - return 1 + max( - [0] + - [topological_height(consumer) - for pcoll in descriptor.transforms[transform_id].outputs.values() - for consumer in pcoll_consumers[pcoll]]) - - return [get_operation(transform_id) - for transform_id in sorted( - descriptor.transforms, key=topological_height, reverse=True)] - def process_bundle(self, request, instruction_id): - ops = self.create_execution_tree( - self.fns[request.process_bundle_descriptor_reference]) - - expected_inputs = [] - for op in ops: - if isinstance(op, DataOutputOperation): - # TODO(robertwb): Is there a better way to pass the instruction id to - # the operation? - op.set_output_stream(op.data_channel.output_stream( - instruction_id, op.target)) - elif isinstance(op, DataInputOperation): - # We must wait until we receive "end of stream" for each of these ops. - expected_inputs.append(op) - - # Start all operations. - for op in reversed(ops): - logging.info('start %s', op) - op.start() - - # Inject inputs from data plane. - for input_op in expected_inputs: - for data in input_op.data_channel.input_elements( - instruction_id, [input_op.target]): - # ignores input name - input_op.process_encoded(data.data) - - # Finish all operations. 
- for op in ops: - logging.info('finish %s', op) - op.finish() + bundle_processor.BundleProcessor( + self.fns[request.process_bundle_descriptor_reference], + self.state_handler, + self.data_channel_factory).process_bundle(instruction_id) return beam_fn_api_pb2.ProcessBundleResponse() - - -class BeamTransformFactory(object): - """Factory for turning transform_protos into executable operations.""" - def __init__(self, descriptor, data_channel_factory, counter_factory, - state_sampler, state_handler): - self.descriptor = descriptor - self.data_channel_factory = data_channel_factory - self.counter_factory = counter_factory - self.state_sampler = state_sampler - self.state_handler = state_handler - self.context = pipeline_context.PipelineContext(descriptor) - - _known_urns = {} - - @classmethod - def register_urn(cls, urn, parameter_type): - def wrapper(func): - cls._known_urns[urn] = func, parameter_type - return func - return wrapper - - def create_operation(self, transform_id, consumers): - transform_proto = self.descriptor.transforms[transform_id] - creator, parameter_type = self._known_urns[transform_proto.spec.urn] - parameter = proto_utils.unpack_Any( - transform_proto.spec.parameter, parameter_type) - return creator(self, transform_id, transform_proto, parameter, consumers) - - def get_coder(self, coder_id): - coder_proto = self.descriptor.coders[coder_id] - if coder_proto.spec.spec.urn: - return self.context.coders.get_by_id(coder_id) - else: - # No URN, assume cloud object encoding json bytes. - return operation_specs.get_coder_from_spec( - json.loads( - proto_utils.unpack_Any(coder_proto.spec.spec.parameter, - wrappers_pb2.BytesValue).value)) - - def get_output_coders(self, transform_proto): - return { - tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) - for tag, pcoll_id in transform_proto.outputs.items() - } - - def get_only_output_coder(self, transform_proto): - return only_element(self.get_output_coders(transform_proto).values()) - - def get_input_coders(self, transform_proto): - return { - tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) - for tag, pcoll_id in transform_proto.inputs.items() - } - - def get_only_input_coder(self, transform_proto): - return only_element(self.get_input_coders(transform_proto).values()) - - # TODO(robertwb): Update all operations to take these in the constructor. 
- @staticmethod - def augment_oldstyle_op(op, step_name, consumers, tag_list=None): - op.step_name = step_name - for tag, op_consumers in consumers.items(): - for consumer in op_consumers: - op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0) - return op - - -@BeamTransformFactory.register_urn( - DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create(factory, transform_id, transform_proto, grpc_port, consumers): - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform_id, - name=only_element(transform_proto.outputs.keys())) - return DataInputOperation( - transform_proto.unique_name, - transform_proto.unique_name, - consumers, - factory.counter_factory, - factory.state_sampler, - factory.get_only_output_coder(transform_proto), - input_target=target, - data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) - - -@BeamTransformFactory.register_urn( - DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create(factory, transform_id, transform_proto, grpc_port, consumers): - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform_id, - name=only_element(transform_proto.inputs.keys())) - return DataOutputOperation( - transform_proto.unique_name, - transform_proto.unique_name, - consumers, - factory.counter_factory, - factory.state_sampler, - # TODO(robertwb): Perhaps this could be distinct from the input coder? - factory.get_only_input_coder(transform_proto), - target=target, - data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) - - -@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue) -def create(factory, transform_id, transform_proto, parameter, consumers): - # The Dataflow runner harness strips the base64 encoding. - source = pickler.loads(base64.b64encode(parameter.value)) - spec = operation_specs.WorkerRead( - iobase.SourceBundle(1.0, source, None, None), - [WindowedValueCoder(source.default_output_coder())]) - return factory.augment_oldstyle_op( - operations.ReadOperation( - transform_proto.unique_name, - spec, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers) - - -@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue) -def create(factory, transform_id, transform_proto, parameter, consumers): - dofn_data = pickler.loads(parameter.value) - if len(dofn_data) == 2: - # Has side input data. - serialized_fn, side_input_data = dofn_data - else: - # No side input data. - serialized_fn, side_input_data = parameter.value, [] - - def create_side_input(tag, coder): - # TODO(robertwb): Extract windows (and keys) out of element data. - # TODO(robertwb): Extract state key from ParDoPayload. 
- return operation_specs.WorkerSideInputSource( - tag=tag, - source=SideInputSource( - factory.state_handler, - beam_fn_api_pb2.StateKey.MultimapSideInput( - key=side_input_tag(transform_id, tag)), - coder=coder)) - output_tags = list(transform_proto.outputs.keys()) - output_coders = factory.get_output_coders(transform_proto) - spec = operation_specs.WorkerDoFn( - serialized_fn=serialized_fn, - output_tags=output_tags, - input=None, - side_inputs=[ - create_side_input(tag, coder) for tag, coder in side_input_data], - output_coders=[output_coders[tag] for tag in output_tags]) - return factory.augment_oldstyle_op( - operations.DoOperation( - transform_proto.unique_name, - spec, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers, - output_tags) - - -@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) -def create(factory, transform_id, transform_proto, unused_parameter, consumers): - return factory.augment_oldstyle_op( - operations.FlattenOperation( - transform_proto.unique_name, - None, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers) From 21fd30283eb4f6b829b06830a3ef04df0a377b06 Mon Sep 17 00:00:00 2001 From: Pawel Kaczmarczyk Date: Mon, 19 Jun 2017 11:10:25 +0200 Subject: [PATCH 191/200] Reformatting Kinesis IO to comply with official code style --- .../sdk/io/kinesis/CheckpointGenerator.java | 6 +- .../beam/sdk/io/kinesis/CustomOptional.java | 111 ++-- .../kinesis/DynamicCheckpointGenerator.java | 52 +- .../io/kinesis/GetKinesisRecordsResult.java | 49 +- .../sdk/io/kinesis/KinesisClientProvider.java | 4 +- .../apache/beam/sdk/io/kinesis/KinesisIO.java | 279 ++++----- .../beam/sdk/io/kinesis/KinesisReader.java | 206 +++---- .../io/kinesis/KinesisReaderCheckpoint.java | 97 ++-- .../beam/sdk/io/kinesis/KinesisRecord.java | 177 +++--- .../sdk/io/kinesis/KinesisRecordCoder.java | 68 +-- .../beam/sdk/io/kinesis/KinesisSource.java | 147 ++--- .../beam/sdk/io/kinesis/RecordFilter.java | 18 +- .../beam/sdk/io/kinesis/RoundRobin.java | 37 +- .../beam/sdk/io/kinesis/ShardCheckpoint.java | 241 ++++---- .../sdk/io/kinesis/ShardRecordsIterator.java | 106 ++-- .../io/kinesis/SimplifiedKinesisClient.java | 215 +++---- .../beam/sdk/io/kinesis/StartingPoint.java | 84 +-- .../io/kinesis/StaticCheckpointGenerator.java | 27 +- .../io/kinesis/TransientKinesisException.java | 7 +- .../sdk/io/kinesis/AmazonKinesisMock.java | 539 +++++++++--------- .../sdk/io/kinesis/CustomOptionalTest.java | 27 +- .../DynamicCheckpointGeneratorTest.java | 33 +- .../sdk/io/kinesis/KinesisMockReadTest.java | 97 ++-- .../kinesis/KinesisReaderCheckpointTest.java | 52 +- .../beam/sdk/io/kinesis/KinesisReaderIT.java | 127 +++-- .../sdk/io/kinesis/KinesisReaderTest.java | 166 +++--- .../io/kinesis/KinesisRecordCoderTest.java | 34 +- .../sdk/io/kinesis/KinesisTestOptions.java | 43 +- .../beam/sdk/io/kinesis/KinesisUploader.java | 70 +-- .../beam/sdk/io/kinesis/RecordFilterTest.java | 52 +- .../beam/sdk/io/kinesis/RoundRobinTest.java | 42 +- .../sdk/io/kinesis/ShardCheckpointTest.java | 203 +++---- .../io/kinesis/ShardRecordsIteratorTest.java | 216 +++---- .../kinesis/SimplifiedKinesisClientTest.java | 351 ++++++------ 34 files changed, 2031 insertions(+), 1952 deletions(-) diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java index 919d85aacb4c7..2629c57c7583f 100644 --- 
a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CheckpointGenerator.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.io.kinesis; - import java.io.Serializable; /** @@ -25,6 +24,7 @@ * How exactly the checkpoint is generated is up to implementing class. */ interface CheckpointGenerator extends Serializable { - KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) - throws TransientKinesisException; + + KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) + throws TransientKinesisException; } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java index 4bed0e39bcfc4..5a282148b0115 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/CustomOptional.java @@ -24,76 +24,79 @@ * Similar to Guava {@code Optional}, but throws {@link NoSuchElementException} for missing element. */ abstract class CustomOptional { - @SuppressWarnings("unchecked") - public static CustomOptional absent() { - return (Absent) Absent.INSTANCE; - } - public static CustomOptional of(T v) { - return new Present<>(v); - } + @SuppressWarnings("unchecked") + public static CustomOptional absent() { + return (Absent) Absent.INSTANCE; + } - public abstract boolean isPresent(); + public static CustomOptional of(T v) { + return new Present<>(v); + } - public abstract T get(); + public abstract boolean isPresent(); - private static class Present extends CustomOptional { - private final T value; + public abstract T get(); - private Present(T value) { - this.value = value; - } + private static class Present extends CustomOptional { - @Override - public boolean isPresent() { - return true; - } + private final T value; - @Override - public T get() { - return value; - } + private Present(T value) { + this.value = value; + } - @Override - public boolean equals(Object o) { - if (!(o instanceof Present)) { - return false; - } + @Override + public boolean isPresent() { + return true; + } - Present present = (Present) o; - return Objects.equals(value, present.value); - } + @Override + public T get() { + return value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Present)) { + return false; + } - @Override - public int hashCode() { - return Objects.hash(value); - } + Present present = (Present) o; + return Objects.equals(value, present.value); } - private static class Absent extends CustomOptional { - private static final Absent INSTANCE = new Absent<>(); + @Override + public int hashCode() { + return Objects.hash(value); + } + } - private Absent() { - } + private static class Absent extends CustomOptional { - @Override - public boolean isPresent() { - return false; - } + private static final Absent INSTANCE = new Absent<>(); - @Override - public T get() { - throw new NoSuchElementException(); - } + private Absent() { + } + + @Override + public boolean isPresent() { + return false; + } - @Override - public boolean equals(Object o) { - return o instanceof Absent; - } + @Override + public T get() { + throw new NoSuchElementException(); + } + + @Override + public boolean equals(Object o) { + return o instanceof Absent; + } - @Override - public int hashCode() { - return 0; - } + @Override + public int hashCode() { + return 0; } + } } 
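The Optional-style wrapper above is meant to be consumed with an explicit isPresent() check before get(); calling get() on the absent instance throws NoSuchElementException. A minimal usage sketch, with a hypothetical tryFetch()/process() call site (KinesisReader later in this patch follows the same isPresent()/get() pattern):

  // Hypothetical call site illustrating the intended contract of CustomOptional.
  CustomOptional<KinesisRecord> record = CustomOptional.absent();
  KinesisRecord fetched = tryFetch();  // hypothetical helper that may return null
  record = (fetched == null)
      ? CustomOptional.<KinesisRecord>absent()
      : CustomOptional.of(fetched);
  if (record.isPresent()) {
    process(record.get());  // safe: a value is present
  }
  // Calling record.get() without the check above would throw NoSuchElementException.
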
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java index 2ec293cfeac21..9933019e4d4cf 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGenerator.java @@ -28,29 +28,31 @@ * List of shards is obtained dynamically on call to {@link #generate(SimplifiedKinesisClient)}. */ class DynamicCheckpointGenerator implements CheckpointGenerator { - private final String streamName; - private final StartingPoint startingPoint; - - public DynamicCheckpointGenerator(String streamName, StartingPoint startingPoint) { - this.streamName = checkNotNull(streamName, "streamName"); - this.startingPoint = checkNotNull(startingPoint, "startingPoint"); - } - - @Override - public KinesisReaderCheckpoint generate(SimplifiedKinesisClient kinesis) - throws TransientKinesisException { - return new KinesisReaderCheckpoint( - transform(kinesis.listShards(streamName), new Function() { - @Override - public ShardCheckpoint apply(Shard shard) { - return new ShardCheckpoint(streamName, shard.getShardId(), startingPoint); - } - }) - ); - } - - @Override - public String toString() { - return String.format("Checkpoint generator for %s: %s", streamName, startingPoint); - } + + private final String streamName; + private final StartingPoint startingPoint; + + public DynamicCheckpointGenerator(String streamName, StartingPoint startingPoint) { + this.streamName = checkNotNull(streamName, "streamName"); + this.startingPoint = checkNotNull(startingPoint, "startingPoint"); + } + + @Override + public KinesisReaderCheckpoint generate(SimplifiedKinesisClient kinesis) + throws TransientKinesisException { + return new KinesisReaderCheckpoint( + transform(kinesis.listShards(streamName), new Function() { + + @Override + public ShardCheckpoint apply(Shard shard) { + return new ShardCheckpoint(streamName, shard.getShardId(), startingPoint); + } + }) + ); + } + + @Override + public String toString() { + return String.format("Checkpoint generator for %s: %s", streamName, startingPoint); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java index 5a34d7d4401ac..f605f5506510b 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/GetKinesisRecordsResult.java @@ -21,6 +21,7 @@ import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; import com.google.common.base.Function; + import java.util.List; import javax.annotation.Nullable; @@ -28,27 +29,29 @@ * Represents the output of 'get' operation on Kinesis stream. 
*/ class GetKinesisRecordsResult { - private final List records; - private final String nextShardIterator; - - public GetKinesisRecordsResult(List records, String nextShardIterator, - final String streamName, final String shardId) { - this.records = transform(records, new Function() { - @Nullable - @Override - public KinesisRecord apply(@Nullable UserRecord input) { - assert input != null; // to make FindBugs happy - return new KinesisRecord(input, streamName, shardId); - } - }); - this.nextShardIterator = nextShardIterator; - } - - public List getRecords() { - return records; - } - - public String getNextShardIterator() { - return nextShardIterator; - } + + private final List records; + private final String nextShardIterator; + + public GetKinesisRecordsResult(List records, String nextShardIterator, + final String streamName, final String shardId) { + this.records = transform(records, new Function() { + + @Nullable + @Override + public KinesisRecord apply(@Nullable UserRecord input) { + assert input != null; // to make FindBugs happy + return new KinesisRecord(input, streamName, shardId); + } + }); + this.nextShardIterator = nextShardIterator; + } + + public List getRecords() { + return records; + } + + public String getNextShardIterator() { + return nextShardIterator; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientProvider.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientProvider.java index c7fd7f618e997..b5b721e23c544 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientProvider.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientProvider.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kinesis; import com.amazonaws.services.kinesis.AmazonKinesis; + import java.io.Serializable; /** @@ -27,5 +28,6 @@ * {@link Serializable} to ensure it can be sent to worker machines. */ interface KinesisClientProvider extends Serializable { - AmazonKinesis get(); + + AmazonKinesis get(); } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java index b85eb6347dbce..bc8ada168b27b 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.io.kinesis; - import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; @@ -29,7 +28,9 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.google.auto.value.AutoValue; + import javax.annotation.Nullable; + import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource; import org.apache.beam.sdk.transforms.PTransform; @@ -102,142 +103,148 @@ */ @Experimental(Experimental.Kind.SOURCE_SINK) public final class KinesisIO { - /** Returns a new {@link Read} transform for reading from Kinesis. */ - public static Read read() { - return new AutoValue_KinesisIO_Read.Builder().setMaxNumRecords(-1).build(); + + /** Returns a new {@link Read} transform for reading from Kinesis. 
*/ + public static Read read() { + return new AutoValue_KinesisIO_Read.Builder().setMaxNumRecords(-1).build(); + } + + /** Implementation of {@link #read}. */ + @AutoValue + public abstract static class Read extends PTransform> { + + @Nullable + abstract String getStreamName(); + + @Nullable + abstract StartingPoint getInitialPosition(); + + @Nullable + abstract KinesisClientProvider getClientProvider(); + + abstract int getMaxNumRecords(); + + @Nullable + abstract Duration getMaxReadTime(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setStreamName(String streamName); + + abstract Builder setInitialPosition(StartingPoint startingPoint); + + abstract Builder setClientProvider(KinesisClientProvider clientProvider); + + abstract Builder setMaxNumRecords(int maxNumRecords); + + abstract Builder setMaxReadTime(Duration maxReadTime); + + abstract Read build(); } - /** Implementation of {@link #read}. */ - @AutoValue - public abstract static class Read extends PTransform> { - @Nullable - abstract String getStreamName(); - - @Nullable - abstract StartingPoint getInitialPosition(); - - @Nullable - abstract KinesisClientProvider getClientProvider(); - - abstract int getMaxNumRecords(); - - @Nullable - abstract Duration getMaxReadTime(); - - abstract Builder toBuilder(); - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setStreamName(String streamName); - abstract Builder setInitialPosition(StartingPoint startingPoint); - abstract Builder setClientProvider(KinesisClientProvider clientProvider); - abstract Builder setMaxNumRecords(int maxNumRecords); - abstract Builder setMaxReadTime(Duration maxReadTime); - - abstract Read build(); - } - - /** - * Specify reading from streamName at some initial position. - */ - public Read from(String streamName, InitialPositionInStream initialPosition) { - return toBuilder() - .setStreamName(streamName) - .setInitialPosition( - new StartingPoint(checkNotNull(initialPosition, "initialPosition"))) - .build(); - } - - /** - * Specify reading from streamName beginning at given {@link Instant}. - * This {@link Instant} must be in the past, i.e. before {@link Instant#now()}. - */ - public Read from(String streamName, Instant initialTimestamp) { - return toBuilder() - .setStreamName(streamName) - .setInitialPosition( - new StartingPoint(checkNotNull(initialTimestamp, "initialTimestamp"))) - .build(); - } - - /** - * Allows to specify custom {@link KinesisClientProvider}. - * {@link KinesisClientProvider} provides {@link AmazonKinesis} instances which are later - * used for communication with Kinesis. - * You should use this method if {@link Read#withClientProvider(String, String, Regions)} - * does not suit your needs. - */ - public Read withClientProvider(KinesisClientProvider kinesisClientProvider) { - return toBuilder().setClientProvider(kinesisClientProvider).build(); - } - - /** - * Specify credential details and region to be used to read from Kinesis. - * If you need more sophisticated credential protocol, then you should look at - * {@link Read#withClientProvider(KinesisClientProvider)}. - */ - public Read withClientProvider(String awsAccessKey, String awsSecretKey, Regions region) { - return withClientProvider(new BasicKinesisProvider(awsAccessKey, awsSecretKey, region)); - } - - /** Specifies to read at most a given number of records. 
*/ - public Read withMaxNumRecords(int maxNumRecords) { - checkArgument( - maxNumRecords > 0, "maxNumRecords must be positive, but was: %s", maxNumRecords); - return toBuilder().setMaxNumRecords(maxNumRecords).build(); - } - - /** Specifies to read at most a given number of records. */ - public Read withMaxReadTime(Duration maxReadTime) { - checkNotNull(maxReadTime, "maxReadTime"); - return toBuilder().setMaxReadTime(maxReadTime).build(); - } - - @Override - public PCollection expand(PBegin input) { - org.apache.beam.sdk.io.Read.Unbounded read = - org.apache.beam.sdk.io.Read.from( - new KinesisSource(getClientProvider(), getStreamName(), getInitialPosition())); - if (getMaxNumRecords() > 0) { - BoundedReadFromUnboundedSource bounded = - read.withMaxNumRecords(getMaxNumRecords()); - return getMaxReadTime() == null - ? input.apply(bounded) - : input.apply(bounded.withMaxReadTime(getMaxReadTime())); - } else { - return getMaxReadTime() == null - ? input.apply(read) - : input.apply(read.withMaxReadTime(getMaxReadTime())); - } - } - - private static final class BasicKinesisProvider implements KinesisClientProvider { - - private final String accessKey; - private final String secretKey; - private final Regions region; - - private BasicKinesisProvider(String accessKey, String secretKey, Regions region) { - this.accessKey = checkNotNull(accessKey, "accessKey"); - this.secretKey = checkNotNull(secretKey, "secretKey"); - this.region = checkNotNull(region, "region"); - } - - - private AWSCredentialsProvider getCredentialsProvider() { - return new StaticCredentialsProvider(new BasicAWSCredentials( - accessKey, - secretKey - )); - - } - - @Override - public AmazonKinesis get() { - AmazonKinesisClient client = new AmazonKinesisClient(getCredentialsProvider()); - client.withRegion(region); - return client; - } - } + /** + * Specify reading from streamName at some initial position. + */ + public Read from(String streamName, InitialPositionInStream initialPosition) { + return toBuilder() + .setStreamName(streamName) + .setInitialPosition( + new StartingPoint(checkNotNull(initialPosition, "initialPosition"))) + .build(); + } + + /** + * Specify reading from streamName beginning at given {@link Instant}. + * This {@link Instant} must be in the past, i.e. before {@link Instant#now()}. + */ + public Read from(String streamName, Instant initialTimestamp) { + return toBuilder() + .setStreamName(streamName) + .setInitialPosition( + new StartingPoint(checkNotNull(initialTimestamp, "initialTimestamp"))) + .build(); + } + + /** + * Allows to specify custom {@link KinesisClientProvider}. + * {@link KinesisClientProvider} provides {@link AmazonKinesis} instances which are later + * used for communication with Kinesis. + * You should use this method if {@link Read#withClientProvider(String, String, Regions)} + * does not suit your needs. + */ + public Read withClientProvider(KinesisClientProvider kinesisClientProvider) { + return toBuilder().setClientProvider(kinesisClientProvider).build(); + } + + /** + * Specify credential details and region to be used to read from Kinesis. + * If you need more sophisticated credential protocol, then you should look at + * {@link Read#withClientProvider(KinesisClientProvider)}. + */ + public Read withClientProvider(String awsAccessKey, String awsSecretKey, Regions region) { + return withClientProvider(new BasicKinesisProvider(awsAccessKey, awsSecretKey, region)); + } + + /** Specifies to read at most a given number of records. 
*/ + public Read withMaxNumRecords(int maxNumRecords) { + checkArgument( + maxNumRecords > 0, "maxNumRecords must be positive, but was: %s", maxNumRecords); + return toBuilder().setMaxNumRecords(maxNumRecords).build(); + } + + /** Specifies to read at most a given number of records. */ + public Read withMaxReadTime(Duration maxReadTime) { + checkNotNull(maxReadTime, "maxReadTime"); + return toBuilder().setMaxReadTime(maxReadTime).build(); + } + + @Override + public PCollection expand(PBegin input) { + org.apache.beam.sdk.io.Read.Unbounded read = + org.apache.beam.sdk.io.Read.from( + new KinesisSource(getClientProvider(), getStreamName(), getInitialPosition())); + if (getMaxNumRecords() > 0) { + BoundedReadFromUnboundedSource bounded = + read.withMaxNumRecords(getMaxNumRecords()); + return getMaxReadTime() == null + ? input.apply(bounded) + : input.apply(bounded.withMaxReadTime(getMaxReadTime())); + } else { + return getMaxReadTime() == null + ? input.apply(read) + : input.apply(read.withMaxReadTime(getMaxReadTime())); + } + } + + private static final class BasicKinesisProvider implements KinesisClientProvider { + + private final String accessKey; + private final String secretKey; + private final Regions region; + + private BasicKinesisProvider(String accessKey, String secretKey, Regions region) { + this.accessKey = checkNotNull(accessKey, "accessKey"); + this.secretKey = checkNotNull(secretKey, "secretKey"); + this.region = checkNotNull(region, "region"); + } + + private AWSCredentialsProvider getCredentialsProvider() { + return new StaticCredentialsProvider(new BasicAWSCredentials( + accessKey, + secretKey + )); + + } + + @Override + public AmazonKinesis get() { + AmazonKinesisClient client = new AmazonKinesisClient(getCredentialsProvider()); + client.withRegion(region); + return client; + } } + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java index 21380941246db..e5c32d20d919a 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java @@ -17,129 +17,129 @@ */ package org.apache.beam.sdk.io.kinesis; - import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.Lists.newArrayList; import java.io.IOException; import java.util.List; import java.util.NoSuchElementException; + import org.apache.beam.sdk.io.UnboundedSource; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - /** * Reads data from multiple kinesis shards in a single thread. * It uses simple round robin algorithm when fetching data from shards. 
*/ class KinesisReader extends UnboundedSource.UnboundedReader { - private static final Logger LOG = LoggerFactory.getLogger(KinesisReader.class); - - private final SimplifiedKinesisClient kinesis; - private final UnboundedSource source; - private final CheckpointGenerator initialCheckpointGenerator; - private RoundRobin shardIterators; - private CustomOptional currentRecord = CustomOptional.absent(); - - public KinesisReader(SimplifiedKinesisClient kinesis, - CheckpointGenerator initialCheckpointGenerator, - UnboundedSource source) { - this.kinesis = checkNotNull(kinesis, "kinesis"); - this.initialCheckpointGenerator = - checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator"); - this.source = source; - } - - /** - * Generates initial checkpoint and instantiates iterators for shards. - */ - @Override - public boolean start() throws IOException { - LOG.info("Starting reader using {}", initialCheckpointGenerator); - - try { - KinesisReaderCheckpoint initialCheckpoint = - initialCheckpointGenerator.generate(kinesis); - List iterators = newArrayList(); - for (ShardCheckpoint checkpoint : initialCheckpoint) { - iterators.add(checkpoint.getShardRecordsIterator(kinesis)); - } - shardIterators = new RoundRobin<>(iterators); - } catch (TransientKinesisException e) { - throw new IOException(e); - } - return advance(); + private static final Logger LOG = LoggerFactory.getLogger(KinesisReader.class); + + private final SimplifiedKinesisClient kinesis; + private final UnboundedSource source; + private final CheckpointGenerator initialCheckpointGenerator; + private RoundRobin shardIterators; + private CustomOptional currentRecord = CustomOptional.absent(); + + public KinesisReader(SimplifiedKinesisClient kinesis, + CheckpointGenerator initialCheckpointGenerator, + UnboundedSource source) { + this.kinesis = checkNotNull(kinesis, "kinesis"); + this.initialCheckpointGenerator = + checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator"); + this.source = source; + } + + /** + * Generates initial checkpoint and instantiates iterators for shards. + */ + @Override + public boolean start() throws IOException { + LOG.info("Starting reader using {}", initialCheckpointGenerator); + + try { + KinesisReaderCheckpoint initialCheckpoint = + initialCheckpointGenerator.generate(kinesis); + List iterators = newArrayList(); + for (ShardCheckpoint checkpoint : initialCheckpoint) { + iterators.add(checkpoint.getShardRecordsIterator(kinesis)); + } + shardIterators = new RoundRobin<>(iterators); + } catch (TransientKinesisException e) { + throw new IOException(e); } - /** - * Moves to the next record in one of the shards. - * If current shard iterator can be move forward (i.e. there's a record present) then we do it. - * If not, we iterate over shards in a round-robin manner. - */ - @Override - public boolean advance() throws IOException { - try { - for (int i = 0; i < shardIterators.size(); ++i) { - currentRecord = shardIterators.getCurrent().next(); - if (currentRecord.isPresent()) { - return true; - } else { - shardIterators.moveForward(); - } - } - } catch (TransientKinesisException e) { - LOG.warn("Transient exception occurred", e); + return advance(); + } + + /** + * Moves to the next record in one of the shards. + * If current shard iterator can be move forward (i.e. there's a record present) then we do it. + * If not, we iterate over shards in a round-robin manner. 
+ */ + @Override + public boolean advance() throws IOException { + try { + for (int i = 0; i < shardIterators.size(); ++i) { + currentRecord = shardIterators.getCurrent().next(); + if (currentRecord.isPresent()) { + return true; + } else { + shardIterators.moveForward(); } - return false; - } - - @Override - public byte[] getCurrentRecordId() throws NoSuchElementException { - return currentRecord.get().getUniqueId(); - } - - @Override - public KinesisRecord getCurrent() throws NoSuchElementException { - return currentRecord.get(); - } - - /** - * When {@link KinesisReader} was advanced to the current record. - * We cannot use approximate arrival timestamp given for each record by Kinesis as it - * is not guaranteed to be accurate - this could lead to mark some records as "late" - * even if they were not. - */ - @Override - public Instant getCurrentTimestamp() throws NoSuchElementException { - return currentRecord.get().getReadTime(); - } - - @Override - public void close() throws IOException { - } - - /** - * Current time. - * We cannot give better approximation of the watermark with current semantics of - * {@link KinesisReader#getCurrentTimestamp()}, because we don't know when the next - * {@link KinesisReader#advance()} will be called. - */ - @Override - public Instant getWatermark() { - return Instant.now(); - } - - @Override - public UnboundedSource.CheckpointMark getCheckpointMark() { - return KinesisReaderCheckpoint.asCurrentStateOf(shardIterators); - } - - @Override - public UnboundedSource getCurrentSource() { - return source; + } + } catch (TransientKinesisException e) { + LOG.warn("Transient exception occurred", e); } + return false; + } + + @Override + public byte[] getCurrentRecordId() throws NoSuchElementException { + return currentRecord.get().getUniqueId(); + } + + @Override + public KinesisRecord getCurrent() throws NoSuchElementException { + return currentRecord.get(); + } + + /** + * When {@link KinesisReader} was advanced to the current record. + * We cannot use approximate arrival timestamp given for each record by Kinesis as it + * is not guaranteed to be accurate - this could lead to mark some records as "late" + * even if they were not. + */ + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return currentRecord.get().getReadTime(); + } + + @Override + public void close() throws IOException { + } + + /** + * Current time. + * We cannot give better approximation of the watermark with current semantics of + * {@link KinesisReader#getCurrentTimestamp()}, because we don't know when the next + * {@link KinesisReader#advance()} will be called. 
+ */ + @Override + public Instant getWatermark() { + return Instant.now(); + } + + @Override + public UnboundedSource.CheckpointMark getCheckpointMark() { + return KinesisReaderCheckpoint.asCurrentStateOf(shardIterators); + } + + @Override + public UnboundedSource getCurrentSource() { + return source; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java index f0fa45d9c26cd..d995e7546449c 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpoint.java @@ -23,11 +23,13 @@ import com.google.common.base.Function; import com.google.common.collect.ImmutableList; + import java.io.IOException; import java.io.Serializable; import java.util.Iterator; import java.util.List; import javax.annotation.Nullable; + import org.apache.beam.sdk.io.UnboundedSource; /** @@ -37,60 +39,61 @@ * This class is immutable. */ class KinesisReaderCheckpoint implements Iterable, UnboundedSource - .CheckpointMark, Serializable { - private final List shardCheckpoints; + .CheckpointMark, Serializable { - public KinesisReaderCheckpoint(Iterable shardCheckpoints) { - this.shardCheckpoints = ImmutableList.copyOf(shardCheckpoints); - } + private final List shardCheckpoints; - public static KinesisReaderCheckpoint asCurrentStateOf(Iterable - iterators) { - return new KinesisReaderCheckpoint(transform(iterators, - new Function() { - - @Nullable - @Override - public ShardCheckpoint apply(@Nullable - ShardRecordsIterator shardRecordsIterator) { - assert shardRecordsIterator != null; - return shardRecordsIterator.getCheckpoint(); - } - })); - } + public KinesisReaderCheckpoint(Iterable shardCheckpoints) { + this.shardCheckpoints = ImmutableList.copyOf(shardCheckpoints); + } - /** - * Splits given multi-shard checkpoint into partitions of approximately equal size. - * - * @param desiredNumSplits - upper limit for number of partitions to generate. - * @return list of checkpoints covering consecutive partitions of current checkpoint. - */ - public List splitInto(int desiredNumSplits) { - int partitionSize = divideAndRoundUp(shardCheckpoints.size(), desiredNumSplits); - - List checkpoints = newArrayList(); - for (List shardPartition : partition(shardCheckpoints, partitionSize)) { - checkpoints.add(new KinesisReaderCheckpoint(shardPartition)); - } - return checkpoints; - } + public static KinesisReaderCheckpoint asCurrentStateOf(Iterable + iterators) { + return new KinesisReaderCheckpoint(transform(iterators, + new Function() { - private int divideAndRoundUp(int nominator, int denominator) { - return (nominator + denominator - 1) / denominator; - } + @Nullable + @Override + public ShardCheckpoint apply(@Nullable + ShardRecordsIterator shardRecordsIterator) { + assert shardRecordsIterator != null; + return shardRecordsIterator.getCheckpoint(); + } + })); + } - @Override - public void finalizeCheckpoint() throws IOException { + /** + * Splits given multi-shard checkpoint into partitions of approximately equal size. + * + * @param desiredNumSplits - upper limit for number of partitions to generate. + * @return list of checkpoints covering consecutive partitions of current checkpoint. 
+ */ + public List splitInto(int desiredNumSplits) { + int partitionSize = divideAndRoundUp(shardCheckpoints.size(), desiredNumSplits); + List checkpoints = newArrayList(); + for (List shardPartition : partition(shardCheckpoints, partitionSize)) { + checkpoints.add(new KinesisReaderCheckpoint(shardPartition)); } + return checkpoints; + } - @Override - public String toString() { - return shardCheckpoints.toString(); - } + private int divideAndRoundUp(int nominator, int denominator) { + return (nominator + denominator - 1) / denominator; + } - @Override - public Iterator iterator() { - return shardCheckpoints.iterator(); - } + @Override + public void finalizeCheckpoint() throws IOException { + + } + + @Override + public String toString() { + return shardCheckpoints.toString(); + } + + @Override + public Iterator iterator() { + return shardCheckpoints.iterator(); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java index 02b5370f7a5c6..057b7bb5ff7ec 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecord.java @@ -22,7 +22,9 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; import com.google.common.base.Charsets; + import java.nio.ByteBuffer; + import org.apache.commons.lang.builder.EqualsBuilder; import org.joda.time.Instant; @@ -30,91 +32,92 @@ * {@link UserRecord} enhanced with utility methods. */ public class KinesisRecord { - private Instant readTime; - private String streamName; - private String shardId; - private long subSequenceNumber; - private String sequenceNumber; - private Instant approximateArrivalTimestamp; - private ByteBuffer data; - private String partitionKey; - - public KinesisRecord(UserRecord record, String streamName, String shardId) { - this(record.getData(), record.getSequenceNumber(), record.getSubSequenceNumber(), - record.getPartitionKey(), - new Instant(record.getApproximateArrivalTimestamp()), - Instant.now(), - streamName, shardId); - } - - public KinesisRecord(ByteBuffer data, String sequenceNumber, long subSequenceNumber, - String partitionKey, Instant approximateArrivalTimestamp, - Instant readTime, - String streamName, String shardId) { - this.data = data; - this.sequenceNumber = sequenceNumber; - this.subSequenceNumber = subSequenceNumber; - this.partitionKey = partitionKey; - this.approximateArrivalTimestamp = approximateArrivalTimestamp; - this.readTime = readTime; - this.streamName = streamName; - this.shardId = shardId; - } - - public ExtendedSequenceNumber getExtendedSequenceNumber() { - return new ExtendedSequenceNumber(getSequenceNumber(), getSubSequenceNumber()); - } - - /*** - * @return unique id of the record based on its position in the stream - */ - public byte[] getUniqueId() { - return getExtendedSequenceNumber().toString().getBytes(Charsets.UTF_8); - } - - public Instant getReadTime() { - return readTime; - } - - public String getStreamName() { - return streamName; - } - - public String getShardId() { - return shardId; - } - - public byte[] getDataAsBytes() { - return getData().array(); - } - - @Override - public boolean equals(Object obj) { - return EqualsBuilder.reflectionEquals(this, obj); - } - - @Override - public int hashCode() { - return reflectionHashCode(this); - } - - 
public long getSubSequenceNumber() { - return subSequenceNumber; - } - - public String getSequenceNumber() { - return sequenceNumber; - } - - public Instant getApproximateArrivalTimestamp() { - return approximateArrivalTimestamp; - } - - public ByteBuffer getData() { - return data; - } - - public String getPartitionKey() { - return partitionKey; - } + + private Instant readTime; + private String streamName; + private String shardId; + private long subSequenceNumber; + private String sequenceNumber; + private Instant approximateArrivalTimestamp; + private ByteBuffer data; + private String partitionKey; + + public KinesisRecord(UserRecord record, String streamName, String shardId) { + this(record.getData(), record.getSequenceNumber(), record.getSubSequenceNumber(), + record.getPartitionKey(), + new Instant(record.getApproximateArrivalTimestamp()), + Instant.now(), + streamName, shardId); + } + + public KinesisRecord(ByteBuffer data, String sequenceNumber, long subSequenceNumber, + String partitionKey, Instant approximateArrivalTimestamp, + Instant readTime, + String streamName, String shardId) { + this.data = data; + this.sequenceNumber = sequenceNumber; + this.subSequenceNumber = subSequenceNumber; + this.partitionKey = partitionKey; + this.approximateArrivalTimestamp = approximateArrivalTimestamp; + this.readTime = readTime; + this.streamName = streamName; + this.shardId = shardId; + } + + public ExtendedSequenceNumber getExtendedSequenceNumber() { + return new ExtendedSequenceNumber(getSequenceNumber(), getSubSequenceNumber()); + } + + /*** + * @return unique id of the record based on its position in the stream + */ + public byte[] getUniqueId() { + return getExtendedSequenceNumber().toString().getBytes(Charsets.UTF_8); + } + + public Instant getReadTime() { + return readTime; + } + + public String getStreamName() { + return streamName; + } + + public String getShardId() { + return shardId; + } + + public byte[] getDataAsBytes() { + return getData().array(); + } + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public int hashCode() { + return reflectionHashCode(this); + } + + public long getSubSequenceNumber() { + return subSequenceNumber; + } + + public String getSequenceNumber() { + return sequenceNumber; + } + + public Instant getApproximateArrivalTimestamp() { + return approximateArrivalTimestamp; + } + + public ByteBuffer getData() { + return data; + } + + public String getPartitionKey() { + return partitionKey; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java index f233e27d064f4..dcf564d3ec732 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoder.java @@ -21,6 +21,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; + import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; @@ -33,40 +34,41 @@ * A {@link Coder} for {@link KinesisRecord}. 
*/ class KinesisRecordCoder extends AtomicCoder { - private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); - private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of(); - private static final InstantCoder INSTANT_CODER = InstantCoder.of(); - private static final VarLongCoder VAR_LONG_CODER = VarLongCoder.of(); - public static KinesisRecordCoder of() { - return new KinesisRecordCoder(); - } + private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); + private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of(); + private static final InstantCoder INSTANT_CODER = InstantCoder.of(); + private static final VarLongCoder VAR_LONG_CODER = VarLongCoder.of(); + + public static KinesisRecordCoder of() { + return new KinesisRecordCoder(); + } - @Override - public void encode(KinesisRecord value, OutputStream outStream) throws - IOException { - BYTE_ARRAY_CODER.encode(value.getData().array(), outStream); - STRING_CODER.encode(value.getSequenceNumber(), outStream); - STRING_CODER.encode(value.getPartitionKey(), outStream); - INSTANT_CODER.encode(value.getApproximateArrivalTimestamp(), outStream); - VAR_LONG_CODER.encode(value.getSubSequenceNumber(), outStream); - INSTANT_CODER.encode(value.getReadTime(), outStream); - STRING_CODER.encode(value.getStreamName(), outStream); - STRING_CODER.encode(value.getShardId(), outStream); - } + @Override + public void encode(KinesisRecord value, OutputStream outStream) throws + IOException { + BYTE_ARRAY_CODER.encode(value.getData().array(), outStream); + STRING_CODER.encode(value.getSequenceNumber(), outStream); + STRING_CODER.encode(value.getPartitionKey(), outStream); + INSTANT_CODER.encode(value.getApproximateArrivalTimestamp(), outStream); + VAR_LONG_CODER.encode(value.getSubSequenceNumber(), outStream); + INSTANT_CODER.encode(value.getReadTime(), outStream); + STRING_CODER.encode(value.getStreamName(), outStream); + STRING_CODER.encode(value.getShardId(), outStream); + } - @Override - public KinesisRecord decode(InputStream inStream) throws IOException { - ByteBuffer data = ByteBuffer.wrap(BYTE_ARRAY_CODER.decode(inStream)); - String sequenceNumber = STRING_CODER.decode(inStream); - String partitionKey = STRING_CODER.decode(inStream); - Instant approximateArrivalTimestamp = INSTANT_CODER.decode(inStream); - long subSequenceNumber = VAR_LONG_CODER.decode(inStream); - Instant readTimestamp = INSTANT_CODER.decode(inStream); - String streamName = STRING_CODER.decode(inStream); - String shardId = STRING_CODER.decode(inStream); - return new KinesisRecord(data, sequenceNumber, subSequenceNumber, partitionKey, - approximateArrivalTimestamp, readTimestamp, streamName, shardId - ); - } + @Override + public KinesisRecord decode(InputStream inStream) throws IOException { + ByteBuffer data = ByteBuffer.wrap(BYTE_ARRAY_CODER.decode(inStream)); + String sequenceNumber = STRING_CODER.decode(inStream); + String partitionKey = STRING_CODER.decode(inStream); + Instant approximateArrivalTimestamp = INSTANT_CODER.decode(inStream); + long subSequenceNumber = VAR_LONG_CODER.decode(inStream); + Instant readTimestamp = INSTANT_CODER.decode(inStream); + String streamName = STRING_CODER.decode(inStream); + String shardId = STRING_CODER.decode(inStream); + return new KinesisRecord(data, sequenceNumber, subSequenceNumber, partitionKey, + approximateArrivalTimestamp, readTimestamp, streamName, shardId + ); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java 
b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java index 7e67d070d94b5..362792b941a47 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java @@ -21,6 +21,7 @@ import static com.google.common.collect.Lists.newArrayList; import java.util.List; + import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; @@ -28,85 +29,85 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - /** * Represents source for single stream in Kinesis. */ class KinesisSource extends UnboundedSource { - private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class); - - private final KinesisClientProvider kinesis; - private CheckpointGenerator initialCheckpointGenerator; - public KinesisSource(KinesisClientProvider kinesis, String streamName, - StartingPoint startingPoint) { - this(kinesis, new DynamicCheckpointGenerator(streamName, startingPoint)); + private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class); + + private final KinesisClientProvider kinesis; + private CheckpointGenerator initialCheckpointGenerator; + + public KinesisSource(KinesisClientProvider kinesis, String streamName, + StartingPoint startingPoint) { + this(kinesis, new DynamicCheckpointGenerator(streamName, startingPoint)); + } + + private KinesisSource(KinesisClientProvider kinesisClientProvider, + CheckpointGenerator initialCheckpoint) { + this.kinesis = kinesisClientProvider; + this.initialCheckpointGenerator = initialCheckpoint; + validate(); + } + + /** + * Generate splits for reading from the stream. + * Basically, it'll try to evenly split set of shards in the stream into + * {@code desiredNumSplits} partitions. Each partition is then a split. + */ + @Override + public List split(int desiredNumSplits, + PipelineOptions options) throws Exception { + KinesisReaderCheckpoint checkpoint = + initialCheckpointGenerator.generate(SimplifiedKinesisClient.from(kinesis)); + + List sources = newArrayList(); + + for (KinesisReaderCheckpoint partition : checkpoint.splitInto(desiredNumSplits)) { + sources.add(new KinesisSource( + kinesis, + new StaticCheckpointGenerator(partition))); } - - private KinesisSource(KinesisClientProvider kinesisClientProvider, - CheckpointGenerator initialCheckpoint) { - this.kinesis = kinesisClientProvider; - this.initialCheckpointGenerator = initialCheckpoint; - validate(); + return sources; + } + + /** + * Creates reader based on given {@link KinesisReaderCheckpoint}. + * If {@link KinesisReaderCheckpoint} is not given, then we use + * {@code initialCheckpointGenerator} to generate new checkpoint. + */ + @Override + public UnboundedReader createReader(PipelineOptions options, + KinesisReaderCheckpoint checkpointMark) { + + CheckpointGenerator checkpointGenerator = initialCheckpointGenerator; + + if (checkpointMark != null) { + checkpointGenerator = new StaticCheckpointGenerator(checkpointMark); } - /** - * Generate splits for reading from the stream. - * Basically, it'll try to evenly split set of shards in the stream into - * {@code desiredNumSplits} partitions. Each partition is then a split. 
- */ - @Override - public List split(int desiredNumSplits, - PipelineOptions options) throws Exception { - KinesisReaderCheckpoint checkpoint = - initialCheckpointGenerator.generate(SimplifiedKinesisClient.from(kinesis)); - - List sources = newArrayList(); - - for (KinesisReaderCheckpoint partition : checkpoint.splitInto(desiredNumSplits)) { - sources.add(new KinesisSource( - kinesis, - new StaticCheckpointGenerator(partition))); - } - return sources; - } - - /** - * Creates reader based on given {@link KinesisReaderCheckpoint}. - * If {@link KinesisReaderCheckpoint} is not given, then we use - * {@code initialCheckpointGenerator} to generate new checkpoint. - */ - @Override - public UnboundedReader createReader(PipelineOptions options, - KinesisReaderCheckpoint checkpointMark) { - - CheckpointGenerator checkpointGenerator = initialCheckpointGenerator; - - if (checkpointMark != null) { - checkpointGenerator = new StaticCheckpointGenerator(checkpointMark); - } - - LOG.info("Creating new reader using {}", checkpointGenerator); - - return new KinesisReader( - SimplifiedKinesisClient.from(kinesis), - checkpointGenerator, - this); - } - - @Override - public Coder getCheckpointMarkCoder() { - return SerializableCoder.of(KinesisReaderCheckpoint.class); - } - - @Override - public void validate() { - checkNotNull(kinesis); - checkNotNull(initialCheckpointGenerator); - } - - @Override - public Coder getDefaultOutputCoder() { - return KinesisRecordCoder.of(); - } + LOG.info("Creating new reader using {}", checkpointGenerator); + + return new KinesisReader( + SimplifiedKinesisClient.from(kinesis), + checkpointGenerator, + this); + } + + @Override + public Coder getCheckpointMarkCoder() { + return SerializableCoder.of(KinesisReaderCheckpoint.class); + } + + @Override + public void validate() { + checkNotNull(kinesis); + checkNotNull(initialCheckpointGenerator); + } + + @Override + public Coder getDefaultOutputCoder() { + return KinesisRecordCoder.of(); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java index 40e65fc909b73..eca725c20b2ed 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RecordFilter.java @@ -21,7 +21,6 @@ import java.util.List; - /** * Filters out records, which were already processed and checkpointed. * @@ -29,13 +28,14 @@ * accuracy, not with "subSequenceNumber" accuracy. 
*/ class RecordFilter { - public List<KinesisRecord> apply(List<KinesisRecord> records, ShardCheckpoint checkpoint) { - List<KinesisRecord> filteredRecords = newArrayList(); - for (KinesisRecord record : records) { - if (checkpoint.isBeforeOrAt(record)) { - filteredRecords.add(record); - } - } - return filteredRecords; + + public List<KinesisRecord> apply(List<KinesisRecord> records, ShardCheckpoint checkpoint) { + List<KinesisRecord> filteredRecords = newArrayList(); + for (KinesisRecord record : records) { + if (checkpoint.isBeforeOrAt(record)) { + filteredRecords.add(record); + } } + return filteredRecords; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RoundRobin.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RoundRobin.java index e4ff541fdd771..806d982901866 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RoundRobin.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RoundRobin.java @@ -27,27 +27,28 @@ * Very simple implementation of round robin algorithm. */ class RoundRobin<T> implements Iterable<T> { - private final Deque<T> deque; - public RoundRobin(Iterable<T> collection) { - this.deque = newArrayDeque(collection); - checkArgument(!deque.isEmpty(), "Tried to initialize RoundRobin with empty collection"); - } + private final Deque<T> deque; - public T getCurrent() { - return deque.getFirst(); - } + public RoundRobin(Iterable<T> collection) { + this.deque = newArrayDeque(collection); + checkArgument(!deque.isEmpty(), "Tried to initialize RoundRobin with empty collection"); + } - public void moveForward() { - deque.addLast(deque.removeFirst()); - } + public T getCurrent() { + return deque.getFirst(); + } - public int size() { - return deque.size(); - } + public void moveForward() { + deque.addLast(deque.removeFirst()); + } - @Override - public Iterator<T> iterator() { - return deque.iterator(); - } + public int size() { + return deque.size(); + } + + @Override + public Iterator<T> iterator() { + return deque.iterator(); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java index 6aa3504bd2106..95f97b8858735 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardCheckpoint.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.io.kinesis; - import static com.amazonaws.services.kinesis.model.ShardIteratorType.AFTER_SEQUENCE_NUMBER; import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_SEQUENCE_NUMBER; import static com.amazonaws.services.kinesis.model.ShardIteratorType.AT_TIMESTAMP; @@ -27,9 +26,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kinesis.model.ShardIteratorType; + import java.io.Serializable; -import org.joda.time.Instant; +import org.joda.time.Instant; /** * Checkpoint mark for single shard in the stream. @@ -45,131 +45,132 @@ * This class is immutable.
*/ class ShardCheckpoint implements Serializable { - private final String streamName; - private final String shardId; - private final String sequenceNumber; - private final ShardIteratorType shardIteratorType; - private final Long subSequenceNumber; - private final Instant timestamp; - - public ShardCheckpoint(String streamName, String shardId, StartingPoint - startingPoint) { - this(streamName, shardId, - ShardIteratorType.fromValue(startingPoint.getPositionName()), - startingPoint.getTimestamp()); - } - - public ShardCheckpoint(String streamName, String shardId, ShardIteratorType - shardIteratorType, Instant timestamp) { - this(streamName, shardId, shardIteratorType, null, null, timestamp); - } - - public ShardCheckpoint(String streamName, String shardId, ShardIteratorType - shardIteratorType, String sequenceNumber, Long subSequenceNumber) { - this(streamName, shardId, shardIteratorType, sequenceNumber, subSequenceNumber, null); - } - - private ShardCheckpoint(String streamName, String shardId, ShardIteratorType shardIteratorType, - String sequenceNumber, Long subSequenceNumber, Instant timestamp) { - this.shardIteratorType = checkNotNull(shardIteratorType, "shardIteratorType"); - this.streamName = checkNotNull(streamName, "streamName"); - this.shardId = checkNotNull(shardId, "shardId"); - if (shardIteratorType == AT_SEQUENCE_NUMBER || shardIteratorType == AFTER_SEQUENCE_NUMBER) { - checkNotNull(sequenceNumber, - "You must provide sequence number for AT_SEQUENCE_NUMBER" - + " or AFTER_SEQUENCE_NUMBER"); - } else { - checkArgument(sequenceNumber == null, - "Sequence number must be null for LATEST, TRIM_HORIZON or AT_TIMESTAMP"); - } - if (shardIteratorType == AT_TIMESTAMP) { - checkNotNull(timestamp, - "You must provide timestamp for AT_SEQUENCE_NUMBER" - + " or AFTER_SEQUENCE_NUMBER"); - } else { - checkArgument(timestamp == null, - "Timestamp must be null for an iterator type other than AT_TIMESTAMP"); - } - - this.subSequenceNumber = subSequenceNumber; - this.sequenceNumber = sequenceNumber; - this.timestamp = timestamp; - } - - /** - * Used to compare {@link ShardCheckpoint} object to {@link KinesisRecord}. Depending - * on the the underlying shardIteratorType, it will either compare the timestamp or the - * {@link ExtendedSequenceNumber}. 
- * - * @param other - * @return if current checkpoint mark points before or at given {@link ExtendedSequenceNumber} - */ - public boolean isBeforeOrAt(KinesisRecord other) { - if (shardIteratorType == AT_TIMESTAMP) { - return timestamp.compareTo(other.getApproximateArrivalTimestamp()) <= 0; - } - int result = extendedSequenceNumber().compareTo(other.getExtendedSequenceNumber()); - if (result == 0) { - return shardIteratorType == AT_SEQUENCE_NUMBER; - } - return result < 0; - } - - private ExtendedSequenceNumber extendedSequenceNumber() { - String fullSequenceNumber = sequenceNumber; - if (fullSequenceNumber == null) { - fullSequenceNumber = shardIteratorType.toString(); - } - return new ExtendedSequenceNumber(fullSequenceNumber, subSequenceNumber); - } - @Override - public String toString() { - return String.format("Checkpoint %s for stream %s, shard %s: %s", shardIteratorType, - streamName, shardId, - sequenceNumber); + private final String streamName; + private final String shardId; + private final String sequenceNumber; + private final ShardIteratorType shardIteratorType; + private final Long subSequenceNumber; + private final Instant timestamp; + + public ShardCheckpoint(String streamName, String shardId, StartingPoint + startingPoint) { + this(streamName, shardId, + ShardIteratorType.fromValue(startingPoint.getPositionName()), + startingPoint.getTimestamp()); + } + + public ShardCheckpoint(String streamName, String shardId, ShardIteratorType + shardIteratorType, Instant timestamp) { + this(streamName, shardId, shardIteratorType, null, null, timestamp); + } + + public ShardCheckpoint(String streamName, String shardId, ShardIteratorType + shardIteratorType, String sequenceNumber, Long subSequenceNumber) { + this(streamName, shardId, shardIteratorType, sequenceNumber, subSequenceNumber, null); + } + + private ShardCheckpoint(String streamName, String shardId, ShardIteratorType shardIteratorType, + String sequenceNumber, Long subSequenceNumber, Instant timestamp) { + this.shardIteratorType = checkNotNull(shardIteratorType, "shardIteratorType"); + this.streamName = checkNotNull(streamName, "streamName"); + this.shardId = checkNotNull(shardId, "shardId"); + if (shardIteratorType == AT_SEQUENCE_NUMBER || shardIteratorType == AFTER_SEQUENCE_NUMBER) { + checkNotNull(sequenceNumber, + "You must provide sequence number for AT_SEQUENCE_NUMBER" + + " or AFTER_SEQUENCE_NUMBER"); + } else { + checkArgument(sequenceNumber == null, + "Sequence number must be null for LATEST, TRIM_HORIZON or AT_TIMESTAMP"); } - - public ShardRecordsIterator getShardRecordsIterator(SimplifiedKinesisClient kinesis) - throws TransientKinesisException { - return new ShardRecordsIterator(this, kinesis); + if (shardIteratorType == AT_TIMESTAMP) { + checkNotNull(timestamp, + "You must provide timestamp for AT_SEQUENCE_NUMBER" + + " or AFTER_SEQUENCE_NUMBER"); + } else { + checkArgument(timestamp == null, + "Timestamp must be null for an iterator type other than AT_TIMESTAMP"); } - public String getShardIterator(SimplifiedKinesisClient kinesisClient) - throws TransientKinesisException { - if (checkpointIsInTheMiddleOfAUserRecord()) { - return kinesisClient.getShardIterator(streamName, - shardId, AT_SEQUENCE_NUMBER, - sequenceNumber, null); - } - return kinesisClient.getShardIterator(streamName, - shardId, shardIteratorType, - sequenceNumber, timestamp); + this.subSequenceNumber = subSequenceNumber; + this.sequenceNumber = sequenceNumber; + this.timestamp = timestamp; + } + + /** + * Used to compare {@link ShardCheckpoint} 
object to {@link KinesisRecord}. Depending + * on the the underlying shardIteratorType, it will either compare the timestamp or the + * {@link ExtendedSequenceNumber}. + * + * @param other + * @return if current checkpoint mark points before or at given {@link ExtendedSequenceNumber} + */ + public boolean isBeforeOrAt(KinesisRecord other) { + if (shardIteratorType == AT_TIMESTAMP) { + return timestamp.compareTo(other.getApproximateArrivalTimestamp()) <= 0; } - - private boolean checkpointIsInTheMiddleOfAUserRecord() { - return shardIteratorType == AFTER_SEQUENCE_NUMBER && subSequenceNumber != null; + int result = extendedSequenceNumber().compareTo(other.getExtendedSequenceNumber()); + if (result == 0) { + return shardIteratorType == AT_SEQUENCE_NUMBER; } + return result < 0; + } - /** - * Used to advance checkpoint mark to position after given {@link Record}. - * - * @param record - * @return new checkpoint object pointing directly after given {@link Record} - */ - public ShardCheckpoint moveAfter(KinesisRecord record) { - return new ShardCheckpoint( - streamName, shardId, - AFTER_SEQUENCE_NUMBER, - record.getSequenceNumber(), - record.getSubSequenceNumber()); + private ExtendedSequenceNumber extendedSequenceNumber() { + String fullSequenceNumber = sequenceNumber; + if (fullSequenceNumber == null) { + fullSequenceNumber = shardIteratorType.toString(); } - - public String getStreamName() { - return streamName; - } - - public String getShardId() { - return shardId; + return new ExtendedSequenceNumber(fullSequenceNumber, subSequenceNumber); + } + + @Override + public String toString() { + return String.format("Checkpoint %s for stream %s, shard %s: %s", shardIteratorType, + streamName, shardId, + sequenceNumber); + } + + public ShardRecordsIterator getShardRecordsIterator(SimplifiedKinesisClient kinesis) + throws TransientKinesisException { + return new ShardRecordsIterator(this, kinesis); + } + + public String getShardIterator(SimplifiedKinesisClient kinesisClient) + throws TransientKinesisException { + if (checkpointIsInTheMiddleOfAUserRecord()) { + return kinesisClient.getShardIterator(streamName, + shardId, AT_SEQUENCE_NUMBER, + sequenceNumber, null); } + return kinesisClient.getShardIterator(streamName, + shardId, shardIteratorType, + sequenceNumber, timestamp); + } + + private boolean checkpointIsInTheMiddleOfAUserRecord() { + return shardIteratorType == AFTER_SEQUENCE_NUMBER && subSequenceNumber != null; + } + + /** + * Used to advance checkpoint mark to position after given {@link Record}. 
+ * + * @param record + * @return new checkpoint object pointing directly after given {@link Record} + */ + public ShardCheckpoint moveAfter(KinesisRecord record) { + return new ShardCheckpoint( + streamName, shardId, + AFTER_SEQUENCE_NUMBER, + record.getSequenceNumber(), + record.getSubSequenceNumber()); + } + + public String getStreamName() { + return streamName; + } + + public String getShardId() { + return shardId; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java index 872f60453dc8a..a69c6c1e1f791 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIterator.java @@ -21,7 +21,9 @@ import static com.google.common.collect.Queues.newArrayDeque; import com.amazonaws.services.kinesis.model.ExpiredIteratorException; + import java.util.Deque; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,68 +33,68 @@ * Then the caller of {@link ShardRecordsIterator#next()} can read from queue one by one. */ class ShardRecordsIterator { - private static final Logger LOG = LoggerFactory.getLogger(ShardRecordsIterator.class); - private final SimplifiedKinesisClient kinesis; - private final RecordFilter filter; - private ShardCheckpoint checkpoint; - private String shardIterator; - private Deque data = newArrayDeque(); + private static final Logger LOG = LoggerFactory.getLogger(ShardRecordsIterator.class); - public ShardRecordsIterator(final ShardCheckpoint initialCheckpoint, - SimplifiedKinesisClient simplifiedKinesisClient) throws - TransientKinesisException { - this(initialCheckpoint, simplifiedKinesisClient, new RecordFilter()); - } + private final SimplifiedKinesisClient kinesis; + private final RecordFilter filter; + private ShardCheckpoint checkpoint; + private String shardIterator; + private Deque data = newArrayDeque(); - public ShardRecordsIterator(final ShardCheckpoint initialCheckpoint, - SimplifiedKinesisClient simplifiedKinesisClient, - RecordFilter filter) throws - TransientKinesisException { + public ShardRecordsIterator(final ShardCheckpoint initialCheckpoint, + SimplifiedKinesisClient simplifiedKinesisClient) throws + TransientKinesisException { + this(initialCheckpoint, simplifiedKinesisClient, new RecordFilter()); + } - this.checkpoint = checkNotNull(initialCheckpoint, "initialCheckpoint"); - this.filter = checkNotNull(filter, "filter"); - this.kinesis = checkNotNull(simplifiedKinesisClient, "simplifiedKinesisClient"); - shardIterator = checkpoint.getShardIterator(kinesis); - } + public ShardRecordsIterator(final ShardCheckpoint initialCheckpoint, + SimplifiedKinesisClient simplifiedKinesisClient, + RecordFilter filter) throws + TransientKinesisException { - /** - * Returns record if there's any present. - * Returns absent() if there are no new records at this time in the shard. 
- */ - public CustomOptional next() throws TransientKinesisException { - readMoreIfNecessary(); + this.checkpoint = checkNotNull(initialCheckpoint, "initialCheckpoint"); + this.filter = checkNotNull(filter, "filter"); + this.kinesis = checkNotNull(simplifiedKinesisClient, "simplifiedKinesisClient"); + shardIterator = checkpoint.getShardIterator(kinesis); + } - if (data.isEmpty()) { - return CustomOptional.absent(); - } else { - KinesisRecord record = data.removeFirst(); - checkpoint = checkpoint.moveAfter(record); - return CustomOptional.of(record); - } - } + /** + * Returns record if there's any present. + * Returns absent() if there are no new records at this time in the shard. + */ + public CustomOptional next() throws TransientKinesisException { + readMoreIfNecessary(); - private void readMoreIfNecessary() throws TransientKinesisException { - if (data.isEmpty()) { - GetKinesisRecordsResult response; - try { - response = kinesis.getRecords(shardIterator, checkpoint.getStreamName(), - checkpoint.getShardId()); - } catch (ExpiredIteratorException e) { - LOG.info("Refreshing expired iterator", e); - shardIterator = checkpoint.getShardIterator(kinesis); - response = kinesis.getRecords(shardIterator, checkpoint.getStreamName(), - checkpoint.getShardId()); - } - LOG.debug("Fetched {} new records", response.getRecords().size()); - shardIterator = response.getNextShardIterator(); - data.addAll(filter.apply(response.getRecords(), checkpoint)); - } + if (data.isEmpty()) { + return CustomOptional.absent(); + } else { + KinesisRecord record = data.removeFirst(); + checkpoint = checkpoint.moveAfter(record); + return CustomOptional.of(record); } + } - public ShardCheckpoint getCheckpoint() { - return checkpoint; + private void readMoreIfNecessary() throws TransientKinesisException { + if (data.isEmpty()) { + GetKinesisRecordsResult response; + try { + response = kinesis.getRecords(shardIterator, checkpoint.getStreamName(), + checkpoint.getShardId()); + } catch (ExpiredIteratorException e) { + LOG.info("Refreshing expired iterator", e); + shardIterator = checkpoint.getShardIterator(kinesis); + response = kinesis.getRecords(shardIterator, checkpoint.getStreamName(), + checkpoint.getShardId()); + } + LOG.debug("Fetched {} new records", response.getRecords().size()); + shardIterator = response.getNextShardIterator(); + data.addAll(filter.apply(response.getRecords(), checkpoint)); } + } + public ShardCheckpoint getCheckpoint() { + return checkpoint; + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java index 3e3984a8c91c6..80c950f9dc82d 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.io.kinesis; - import com.amazonaws.AmazonServiceException; import com.amazonaws.services.kinesis.AmazonKinesis; import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord; @@ -31,9 +30,11 @@ import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.StreamDescription; import com.google.common.collect.Lists; + import java.util.Date; import java.util.List; import java.util.concurrent.Callable; + import org.joda.time.Instant; /** @@ -41,117 +42,121 @@ * proper error handling. 
*/ class SimplifiedKinesisClient { - private final AmazonKinesis kinesis; - public SimplifiedKinesisClient(AmazonKinesis kinesis) { - this.kinesis = kinesis; - } + private final AmazonKinesis kinesis; - public static SimplifiedKinesisClient from(KinesisClientProvider provider) { - return new SimplifiedKinesisClient(provider.get()); - } + public SimplifiedKinesisClient(AmazonKinesis kinesis) { + this.kinesis = kinesis; + } - public String getShardIterator(final String streamName, final String shardId, - final ShardIteratorType shardIteratorType, - final String startingSequenceNumber, final Instant timestamp) - throws TransientKinesisException { - final Date date = timestamp != null ? timestamp.toDate() : null; - return wrapExceptions(new Callable() { - @Override - public String call() throws Exception { - return kinesis.getShardIterator(new GetShardIteratorRequest() - .withStreamName(streamName) - .withShardId(shardId) - .withShardIteratorType(shardIteratorType) - .withStartingSequenceNumber(startingSequenceNumber) - .withTimestamp(date) - ).getShardIterator(); - } - }); - } + public static SimplifiedKinesisClient from(KinesisClientProvider provider) { + return new SimplifiedKinesisClient(provider.get()); + } - public List listShards(final String streamName) throws TransientKinesisException { - return wrapExceptions(new Callable>() { - @Override - public List call() throws Exception { - List shards = Lists.newArrayList(); - String lastShardId = null; - - StreamDescription description; - do { - description = kinesis.describeStream(streamName, lastShardId) - .getStreamDescription(); - - shards.addAll(description.getShards()); - lastShardId = shards.get(shards.size() - 1).getShardId(); - } while (description.getHasMoreShards()); - - return shards; - } - }); - } + public String getShardIterator(final String streamName, final String shardId, + final ShardIteratorType shardIteratorType, + final String startingSequenceNumber, final Instant timestamp) + throws TransientKinesisException { + final Date date = timestamp != null ? timestamp.toDate() : null; + return wrapExceptions(new Callable() { - /** - * Gets records from Kinesis and deaggregates them if needed. - * - * @return list of deaggregated records - * @throws TransientKinesisException - in case of recoverable situation - */ - public GetKinesisRecordsResult getRecords(String shardIterator, String streamName, - String shardId) throws TransientKinesisException { - return getRecords(shardIterator, streamName, shardId, null); - } + @Override + public String call() throws Exception { + return kinesis.getShardIterator(new GetShardIteratorRequest() + .withStreamName(streamName) + .withShardId(shardId) + .withShardIteratorType(shardIteratorType) + .withStartingSequenceNumber(startingSequenceNumber) + .withTimestamp(date) + ).getShardIterator(); + } + }); + } - /** - * Gets records from Kinesis and deaggregates them if needed. 
- * - * @return list of deaggregated records - * @throws TransientKinesisException - in case of recoverable situation - */ - public GetKinesisRecordsResult getRecords(final String shardIterator, final String streamName, - final String shardId, final Integer limit) - throws - TransientKinesisException { - return wrapExceptions(new Callable() { - @Override - public GetKinesisRecordsResult call() throws Exception { - GetRecordsResult response = kinesis.getRecords(new GetRecordsRequest() - .withShardIterator(shardIterator) - .withLimit(limit)); - return new GetKinesisRecordsResult( - UserRecord.deaggregate(response.getRecords()), - response.getNextShardIterator(), - streamName, shardId); - } - }); - } + public List listShards(final String streamName) throws TransientKinesisException { + return wrapExceptions(new Callable>() { + + @Override + public List call() throws Exception { + List shards = Lists.newArrayList(); + String lastShardId = null; + + StreamDescription description; + do { + description = kinesis.describeStream(streamName, lastShardId) + .getStreamDescription(); + + shards.addAll(description.getShards()); + lastShardId = shards.get(shards.size() - 1).getShardId(); + } while (description.getHasMoreShards()); + + return shards; + } + }); + } + + /** + * Gets records from Kinesis and deaggregates them if needed. + * + * @return list of deaggregated records + * @throws TransientKinesisException - in case of recoverable situation + */ + public GetKinesisRecordsResult getRecords(String shardIterator, String streamName, + String shardId) throws TransientKinesisException { + return getRecords(shardIterator, streamName, shardId, null); + } + + /** + * Gets records from Kinesis and deaggregates them if needed. + * + * @return list of deaggregated records + * @throws TransientKinesisException - in case of recoverable situation + */ + public GetKinesisRecordsResult getRecords(final String shardIterator, final String streamName, + final String shardId, final Integer limit) + throws + TransientKinesisException { + return wrapExceptions(new Callable() { + + @Override + public GetKinesisRecordsResult call() throws Exception { + GetRecordsResult response = kinesis.getRecords(new GetRecordsRequest() + .withShardIterator(shardIterator) + .withLimit(limit)); + return new GetKinesisRecordsResult( + UserRecord.deaggregate(response.getRecords()), + response.getNextShardIterator(), + streamName, shardId); + } + }); + } - /** - * Wraps Amazon specific exceptions into more friendly format. - * - * @throws TransientKinesisException - in case of recoverable situation, i.e. - * the request rate is too high, Kinesis remote service - * failed, network issue, etc. - * @throws ExpiredIteratorException - if iterator needs to be refreshed - * @throws RuntimeException - in all other cases - */ - private T wrapExceptions(Callable callable) throws TransientKinesisException { - try { - return callable.call(); - } catch (ExpiredIteratorException e) { - throw e; - } catch (LimitExceededException | ProvisionedThroughputExceededException e) { - throw new TransientKinesisException( - "Too many requests to Kinesis. Wait some time and retry.", e); - } catch (AmazonServiceException e) { - if (e.getErrorType() == AmazonServiceException.ErrorType.Service) { - throw new TransientKinesisException( - "Kinesis backend failed. 
Wait some time and retry.", e); - } - throw new RuntimeException("Kinesis client side failure", e); - } catch (Exception e) { - throw new RuntimeException("Unknown kinesis failure, when trying to reach kinesis", e); - } + /** + * Wraps Amazon specific exceptions into more friendly format. + * + * @throws TransientKinesisException - in case of recoverable situation, i.e. + * the request rate is too high, Kinesis remote service + * failed, network issue, etc. + * @throws ExpiredIteratorException - if iterator needs to be refreshed + * @throws RuntimeException - in all other cases + */ + private T wrapExceptions(Callable callable) throws TransientKinesisException { + try { + return callable.call(); + } catch (ExpiredIteratorException e) { + throw e; + } catch (LimitExceededException | ProvisionedThroughputExceededException e) { + throw new TransientKinesisException( + "Too many requests to Kinesis. Wait some time and retry.", e); + } catch (AmazonServiceException e) { + if (e.getErrorType() == AmazonServiceException.ErrorType.Service) { + throw new TransientKinesisException( + "Kinesis backend failed. Wait some time and retry.", e); + } + throw new RuntimeException("Kinesis client side failure", e); + } catch (Exception e) { + throw new RuntimeException("Unknown kinesis failure, when trying to reach kinesis", e); } + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java index d8842c4a6cd39..f9298fa54debf 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StartingPoint.java @@ -17,13 +17,14 @@ */ package org.apache.beam.sdk.io.kinesis; - import static com.google.common.base.Preconditions.checkNotNull; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.model.ShardIteratorType; + import java.io.Serializable; import java.util.Objects; + import org.joda.time.Instant; /** @@ -32,54 +33,55 @@ * in which case the reader will start reading at the specified point in time. */ class StartingPoint implements Serializable { - private final InitialPositionInStream position; - private final Instant timestamp; - public StartingPoint(InitialPositionInStream position) { - this.position = checkNotNull(position, "position"); - this.timestamp = null; - } + private final InitialPositionInStream position; + private final Instant timestamp; - public StartingPoint(Instant timestamp) { - this.timestamp = checkNotNull(timestamp, "timestamp"); - this.position = null; - } + public StartingPoint(InitialPositionInStream position) { + this.position = checkNotNull(position, "position"); + this.timestamp = null; + } - public InitialPositionInStream getPosition() { - return position; - } + public StartingPoint(Instant timestamp) { + this.timestamp = checkNotNull(timestamp, "timestamp"); + this.position = null; + } - public String getPositionName() { - return position != null ? position.name() : ShardIteratorType.AT_TIMESTAMP.name(); - } + public InitialPositionInStream getPosition() { + return position; + } - public Instant getTimestamp() { - return timestamp != null ? timestamp : null; - } + public String getPositionName() { + return position != null ? 
position.name() : ShardIteratorType.AT_TIMESTAMP.name(); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - StartingPoint that = (StartingPoint) o; - return position == that.position && Objects.equals(timestamp, that.timestamp); - } + public Instant getTimestamp() { + return timestamp != null ? timestamp : null; + } - @Override - public int hashCode() { - return Objects.hash(position, timestamp); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + if (o == null || getClass() != o.getClass()) { + return false; + } + StartingPoint that = (StartingPoint) o; + return position == that.position && Objects.equals(timestamp, that.timestamp); + } + + @Override + public int hashCode() { + return Objects.hash(position, timestamp); + } - @Override - public String toString() { - if (timestamp == null) { - return position.toString(); - } else { - return "Starting at timestamp " + timestamp; - } + @Override + public String toString() { + if (timestamp == null) { + return position.toString(); + } else { + return "Starting at timestamp " + timestamp; } + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java index 22dc9734f441e..1ec865d66ad6e 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/StaticCheckpointGenerator.java @@ -23,20 +23,21 @@ * Always returns the same instance of checkpoint. */ class StaticCheckpointGenerator implements CheckpointGenerator { - private final KinesisReaderCheckpoint checkpoint; - public StaticCheckpointGenerator(KinesisReaderCheckpoint checkpoint) { - checkNotNull(checkpoint, "checkpoint"); - this.checkpoint = checkpoint; - } + private final KinesisReaderCheckpoint checkpoint; - @Override - public KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) { - return checkpoint; - } + public StaticCheckpointGenerator(KinesisReaderCheckpoint checkpoint) { + checkNotNull(checkpoint, "checkpoint"); + this.checkpoint = checkpoint; + } - @Override - public String toString() { - return checkpoint.toString(); - } + @Override + public KinesisReaderCheckpoint generate(SimplifiedKinesisClient client) { + return checkpoint; + } + + @Override + public String toString() { + return checkpoint.toString(); + } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java index 57ad8a89103e9..68ca0d7c5ac8e 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/TransientKinesisException.java @@ -23,7 +23,8 @@ * A transient exception thrown by Kinesis. 
*/ class TransientKinesisException extends Exception { - public TransientKinesisException(String s, AmazonServiceException e) { - super(s, e); - } + + public TransientKinesisException(String s, AmazonServiceException e) { + super(s, e); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java index 046c9d9126d2f..994d6e3c3f751 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/AmazonKinesisMock.java @@ -66,10 +66,12 @@ import com.amazonaws.services.kinesis.model.SplitShardResult; import com.amazonaws.services.kinesis.model.StreamDescription; import com.google.common.base.Function; + import java.io.Serializable; import java.nio.ByteBuffer; import java.util.List; import javax.annotation.Nullable; + import org.apache.commons.lang.builder.EqualsBuilder; import org.joda.time.Instant; @@ -78,298 +80,301 @@ */ class AmazonKinesisMock implements AmazonKinesis { - static class TestData implements Serializable { - private final String data; - private final Instant arrivalTimestamp; - private final String sequenceNumber; - - public TestData(KinesisRecord record) { - this(new String(record.getData().array()), - record.getApproximateArrivalTimestamp(), - record.getSequenceNumber()); - } - - public TestData(String data, Instant arrivalTimestamp, String sequenceNumber) { - this.data = data; - this.arrivalTimestamp = arrivalTimestamp; - this.sequenceNumber = sequenceNumber; - } - - public Record convertToRecord() { - return new Record(). - withApproximateArrivalTimestamp(arrivalTimestamp.toDate()). - withData(ByteBuffer.wrap(data.getBytes())). - withSequenceNumber(sequenceNumber). - withPartitionKey(""); - } - - @Override - public boolean equals(Object obj) { - return EqualsBuilder.reflectionEquals(this, obj); - } - - @Override - public int hashCode() { - return reflectionHashCode(this); - } - } - - static class Provider implements KinesisClientProvider { - - private final List> shardedData; - private final int numberOfRecordsPerGet; - - public Provider(List> shardedData, int numberOfRecordsPerGet) { - this.shardedData = shardedData; - this.numberOfRecordsPerGet = numberOfRecordsPerGet; - } - - @Override - public AmazonKinesis get() { - return new AmazonKinesisMock(transform(shardedData, - new Function, List>() { - @Override - public List apply(@Nullable List testDatas) { - return transform(testDatas, new Function() { - @Override - public Record apply(@Nullable TestData testData) { - return testData.convertToRecord(); - } - }); - } - }), numberOfRecordsPerGet); - } - } - - private final List> shardedData; - private final int numberOfRecordsPerGet; - - public AmazonKinesisMock(List> shardedData, int numberOfRecordsPerGet) { - this.shardedData = shardedData; - this.numberOfRecordsPerGet = numberOfRecordsPerGet; - } - - @Override - public GetRecordsResult getRecords(GetRecordsRequest getRecordsRequest) { - String[] shardIteratorParts = getRecordsRequest.getShardIterator().split(":"); - int shardId = parseInt(shardIteratorParts[0]); - int startingRecord = parseInt(shardIteratorParts[1]); - List shardData = shardedData.get(shardId); - - int toIndex = min(startingRecord + numberOfRecordsPerGet, shardData.size()); - int fromIndex = min(startingRecord, toIndex); - return new GetRecordsResult(). - withRecords(shardData.subList(fromIndex, toIndex)). 
- withNextShardIterator(String.format("%s:%s", shardId, toIndex)); - } - - @Override - public GetShardIteratorResult getShardIterator( - GetShardIteratorRequest getShardIteratorRequest) { - ShardIteratorType shardIteratorType = ShardIteratorType.fromValue( - getShardIteratorRequest.getShardIteratorType()); - - String shardIterator; - if (shardIteratorType == ShardIteratorType.TRIM_HORIZON) { - shardIterator = String.format("%s:%s", getShardIteratorRequest.getShardId(), 0); - } else { - throw new RuntimeException("Not implemented"); - } - - return new GetShardIteratorResult().withShardIterator(shardIterator); - } - - @Override - public DescribeStreamResult describeStream(String streamName, String exclusiveStartShardId) { - int nextShardId = 0; - if (exclusiveStartShardId != null) { - nextShardId = parseInt(exclusiveStartShardId) + 1; - } - boolean hasMoreShards = nextShardId + 1 < shardedData.size(); - - List shards = newArrayList(); - if (nextShardId < shardedData.size()) { - shards.add(new Shard().withShardId(Integer.toString(nextShardId))); - } - - return new DescribeStreamResult().withStreamDescription( - new StreamDescription().withHasMoreShards(hasMoreShards).withShards(shards) - ); - } - - @Override - public void setEndpoint(String endpoint) { - - } - - @Override - public void setRegion(Region region) { - - } - - @Override - public AddTagsToStreamResult addTagsToStream(AddTagsToStreamRequest addTagsToStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public CreateStreamResult createStream(CreateStreamRequest createStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public CreateStreamResult createStream(String streamName, Integer shardCount) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DecreaseStreamRetentionPeriodResult decreaseStreamRetentionPeriod( - DecreaseStreamRetentionPeriodRequest decreaseStreamRetentionPeriodRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DeleteStreamResult deleteStream(DeleteStreamRequest deleteStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DeleteStreamResult deleteStream(String streamName) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamResult describeStream(DescribeStreamRequest describeStreamRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamResult describeStream(String streamName) { - - throw new RuntimeException("Not implemented"); - } - - @Override - public DescribeStreamResult describeStream(String streamName, - Integer limit, String exclusiveStartShardId) { - throw new RuntimeException("Not implemented"); - } - - @Override - public DisableEnhancedMonitoringResult disableEnhancedMonitoring( - DisableEnhancedMonitoringRequest disableEnhancedMonitoringRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public EnableEnhancedMonitoringResult enableEnhancedMonitoring( - EnableEnhancedMonitoringRequest enableEnhancedMonitoringRequest) { - throw new RuntimeException("Not implemented"); - } - - @Override - public GetShardIteratorResult getShardIterator(String streamName, - String shardId, - String shardIteratorType) { - throw new RuntimeException("Not implemented"); - } - - @Override - public GetShardIteratorResult getShardIterator(String streamName, - String shardId, - String shardIteratorType, - String startingSequenceNumber) { - throw new 
RuntimeException("Not implemented"); - } - - @Override - public IncreaseStreamRetentionPeriodResult increaseStreamRetentionPeriod( - IncreaseStreamRetentionPeriodRequest increaseStreamRetentionPeriodRequest) { - throw new RuntimeException("Not implemented"); - } + static class TestData implements Serializable { - @Override - public ListStreamsResult listStreams(ListStreamsRequest listStreamsRequest) { - throw new RuntimeException("Not implemented"); - } + private final String data; + private final Instant arrivalTimestamp; + private final String sequenceNumber; - @Override - public ListStreamsResult listStreams() { - throw new RuntimeException("Not implemented"); + public TestData(KinesisRecord record) { + this(new String(record.getData().array()), + record.getApproximateArrivalTimestamp(), + record.getSequenceNumber()); } - @Override - public ListStreamsResult listStreams(String exclusiveStartStreamName) { - throw new RuntimeException("Not implemented"); + public TestData(String data, Instant arrivalTimestamp, String sequenceNumber) { + this.data = data; + this.arrivalTimestamp = arrivalTimestamp; + this.sequenceNumber = sequenceNumber; } - @Override - public ListStreamsResult listStreams(Integer limit, String exclusiveStartStreamName) { - throw new RuntimeException("Not implemented"); + public Record convertToRecord() { + return new Record(). + withApproximateArrivalTimestamp(arrivalTimestamp.toDate()). + withData(ByteBuffer.wrap(data.getBytes())). + withSequenceNumber(sequenceNumber). + withPartitionKey(""); } @Override - public ListTagsForStreamResult listTagsForStream( - ListTagsForStreamRequest listTagsForStreamRequest) { - throw new RuntimeException("Not implemented"); + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); } @Override - public MergeShardsResult mergeShards(MergeShardsRequest mergeShardsRequest) { - throw new RuntimeException("Not implemented"); + public int hashCode() { + return reflectionHashCode(this); } + } - @Override - public MergeShardsResult mergeShards(String streamName, - String shardToMerge, String adjacentShardToMerge) { - throw new RuntimeException("Not implemented"); - } + static class Provider implements KinesisClientProvider { - @Override - public PutRecordResult putRecord(PutRecordRequest putRecordRequest) { - throw new RuntimeException("Not implemented"); - } + private final List> shardedData; + private final int numberOfRecordsPerGet; - @Override - public PutRecordResult putRecord(String streamName, ByteBuffer data, String partitionKey) { - throw new RuntimeException("Not implemented"); + public Provider(List> shardedData, int numberOfRecordsPerGet) { + this.shardedData = shardedData; + this.numberOfRecordsPerGet = numberOfRecordsPerGet; } @Override - public PutRecordResult putRecord(String streamName, ByteBuffer data, - String partitionKey, String sequenceNumberForOrdering) { - throw new RuntimeException("Not implemented"); - } + public AmazonKinesis get() { + return new AmazonKinesisMock(transform(shardedData, + new Function, List>() { - @Override - public PutRecordsResult putRecords(PutRecordsRequest putRecordsRequest) { - throw new RuntimeException("Not implemented"); - } + @Override + public List apply(@Nullable List testDatas) { + return transform(testDatas, new Function() { - @Override - public RemoveTagsFromStreamResult removeTagsFromStream( - RemoveTagsFromStreamRequest removeTagsFromStreamRequest) { - throw new RuntimeException("Not implemented"); + @Override + public Record apply(@Nullable TestData 
testData) { + return testData.convertToRecord(); + } + }); + } + }), numberOfRecordsPerGet); } + } - @Override - public SplitShardResult splitShard(SplitShardRequest splitShardRequest) { - throw new RuntimeException("Not implemented"); - } + private final List> shardedData; + private final int numberOfRecordsPerGet; - @Override - public SplitShardResult splitShard(String streamName, - String shardToSplit, String newStartingHashKey) { - throw new RuntimeException("Not implemented"); - } + public AmazonKinesisMock(List> shardedData, int numberOfRecordsPerGet) { + this.shardedData = shardedData; + this.numberOfRecordsPerGet = numberOfRecordsPerGet; + } - @Override - public void shutdown() { + @Override + public GetRecordsResult getRecords(GetRecordsRequest getRecordsRequest) { + String[] shardIteratorParts = getRecordsRequest.getShardIterator().split(":"); + int shardId = parseInt(shardIteratorParts[0]); + int startingRecord = parseInt(shardIteratorParts[1]); + List shardData = shardedData.get(shardId); - } + int toIndex = min(startingRecord + numberOfRecordsPerGet, shardData.size()); + int fromIndex = min(startingRecord, toIndex); + return new GetRecordsResult(). + withRecords(shardData.subList(fromIndex, toIndex)). + withNextShardIterator(String.format("%s:%s", shardId, toIndex)); + } - @Override - public ResponseMetadata getCachedResponseMetadata(AmazonWebServiceRequest request) { - throw new RuntimeException("Not implemented"); - } + @Override + public GetShardIteratorResult getShardIterator( + GetShardIteratorRequest getShardIteratorRequest) { + ShardIteratorType shardIteratorType = ShardIteratorType.fromValue( + getShardIteratorRequest.getShardIteratorType()); + + String shardIterator; + if (shardIteratorType == ShardIteratorType.TRIM_HORIZON) { + shardIterator = String.format("%s:%s", getShardIteratorRequest.getShardId(), 0); + } else { + throw new RuntimeException("Not implemented"); + } + + return new GetShardIteratorResult().withShardIterator(shardIterator); + } + + @Override + public DescribeStreamResult describeStream(String streamName, String exclusiveStartShardId) { + int nextShardId = 0; + if (exclusiveStartShardId != null) { + nextShardId = parseInt(exclusiveStartShardId) + 1; + } + boolean hasMoreShards = nextShardId + 1 < shardedData.size(); + + List shards = newArrayList(); + if (nextShardId < shardedData.size()) { + shards.add(new Shard().withShardId(Integer.toString(nextShardId))); + } + + return new DescribeStreamResult().withStreamDescription( + new StreamDescription().withHasMoreShards(hasMoreShards).withShards(shards) + ); + } + + @Override + public void setEndpoint(String endpoint) { + + } + + @Override + public void setRegion(Region region) { + + } + + @Override + public AddTagsToStreamResult addTagsToStream(AddTagsToStreamRequest addTagsToStreamRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public CreateStreamResult createStream(CreateStreamRequest createStreamRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public CreateStreamResult createStream(String streamName, Integer shardCount) { + throw new RuntimeException("Not implemented"); + } + + @Override + public DecreaseStreamRetentionPeriodResult decreaseStreamRetentionPeriod( + DecreaseStreamRetentionPeriodRequest decreaseStreamRetentionPeriodRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public DeleteStreamResult deleteStream(DeleteStreamRequest deleteStreamRequest) { + throw new RuntimeException("Not implemented"); + 
} + + @Override + public DeleteStreamResult deleteStream(String streamName) { + throw new RuntimeException("Not implemented"); + } + + @Override + public DescribeStreamResult describeStream(DescribeStreamRequest describeStreamRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public DescribeStreamResult describeStream(String streamName) { + + throw new RuntimeException("Not implemented"); + } + + @Override + public DescribeStreamResult describeStream(String streamName, + Integer limit, String exclusiveStartShardId) { + throw new RuntimeException("Not implemented"); + } + + @Override + public DisableEnhancedMonitoringResult disableEnhancedMonitoring( + DisableEnhancedMonitoringRequest disableEnhancedMonitoringRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public EnableEnhancedMonitoringResult enableEnhancedMonitoring( + EnableEnhancedMonitoringRequest enableEnhancedMonitoringRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public GetShardIteratorResult getShardIterator(String streamName, + String shardId, + String shardIteratorType) { + throw new RuntimeException("Not implemented"); + } + + @Override + public GetShardIteratorResult getShardIterator(String streamName, + String shardId, + String shardIteratorType, + String startingSequenceNumber) { + throw new RuntimeException("Not implemented"); + } + + @Override + public IncreaseStreamRetentionPeriodResult increaseStreamRetentionPeriod( + IncreaseStreamRetentionPeriodRequest increaseStreamRetentionPeriodRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public ListStreamsResult listStreams(ListStreamsRequest listStreamsRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public ListStreamsResult listStreams() { + throw new RuntimeException("Not implemented"); + } + + @Override + public ListStreamsResult listStreams(String exclusiveStartStreamName) { + throw new RuntimeException("Not implemented"); + } + + @Override + public ListStreamsResult listStreams(Integer limit, String exclusiveStartStreamName) { + throw new RuntimeException("Not implemented"); + } + + @Override + public ListTagsForStreamResult listTagsForStream( + ListTagsForStreamRequest listTagsForStreamRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public MergeShardsResult mergeShards(MergeShardsRequest mergeShardsRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public MergeShardsResult mergeShards(String streamName, + String shardToMerge, String adjacentShardToMerge) { + throw new RuntimeException("Not implemented"); + } + + @Override + public PutRecordResult putRecord(PutRecordRequest putRecordRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public PutRecordResult putRecord(String streamName, ByteBuffer data, String partitionKey) { + throw new RuntimeException("Not implemented"); + } + + @Override + public PutRecordResult putRecord(String streamName, ByteBuffer data, + String partitionKey, String sequenceNumberForOrdering) { + throw new RuntimeException("Not implemented"); + } + + @Override + public PutRecordsResult putRecords(PutRecordsRequest putRecordsRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public RemoveTagsFromStreamResult removeTagsFromStream( + RemoveTagsFromStreamRequest removeTagsFromStreamRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public SplitShardResult 
splitShard(SplitShardRequest splitShardRequest) { + throw new RuntimeException("Not implemented"); + } + + @Override + public SplitShardResult splitShard(String streamName, + String shardToSplit, String newStartingHashKey) { + throw new RuntimeException("Not implemented"); + } + + @Override + public void shutdown() { + + } + + @Override + public ResponseMetadata getCachedResponseMetadata(AmazonWebServiceRequest request) { + throw new RuntimeException("Not implemented"); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java index 00acffeae6191..0b16bb77ba0d0 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/CustomOptionalTest.java @@ -18,24 +18,27 @@ package org.apache.beam.sdk.io.kinesis; import com.google.common.testing.EqualsTester; + import java.util.NoSuchElementException; + import org.junit.Test; /** * Tests {@link CustomOptional}. */ public class CustomOptionalTest { - @Test(expected = NoSuchElementException.class) - public void absentThrowsNoSuchElementExceptionOnGet() { - CustomOptional.absent().get(); - } - @Test - public void testEqualsAndHashCode() { - new EqualsTester() - .addEqualityGroup(CustomOptional.absent(), CustomOptional.absent()) - .addEqualityGroup(CustomOptional.of(3), CustomOptional.of(3)) - .addEqualityGroup(CustomOptional.of(11)) - .addEqualityGroup(CustomOptional.of("3")).testEquals(); - } + @Test(expected = NoSuchElementException.class) + public void absentThrowsNoSuchElementExceptionOnGet() { + CustomOptional.absent().get(); + } + + @Test + public void testEqualsAndHashCode() { + new EqualsTester() + .addEqualityGroup(CustomOptional.absent(), CustomOptional.absent()) + .addEqualityGroup(CustomOptional.of(3), CustomOptional.of(3)) + .addEqualityGroup(CustomOptional.of(11)) + .addEqualityGroup(CustomOptional.of("3")).testEquals(); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java index c92ac9a4b2939..1bb97176d83ce 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/DynamicCheckpointGeneratorTest.java @@ -28,30 +28,29 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; - /*** */ @RunWith(MockitoJUnitRunner.class) public class DynamicCheckpointGeneratorTest { - @Mock - private SimplifiedKinesisClient kinesisClient; - @Mock - private Shard shard1, shard2, shard3; + @Mock + private SimplifiedKinesisClient kinesisClient; + @Mock + private Shard shard1, shard2, shard3; - @Test - public void shouldMapAllShardsToCheckpoints() throws Exception { - given(shard1.getShardId()).willReturn("shard-01"); - given(shard2.getShardId()).willReturn("shard-02"); - given(shard3.getShardId()).willReturn("shard-03"); - given(kinesisClient.listShards("stream")).willReturn(asList(shard1, shard2, shard3)); + @Test + public void shouldMapAllShardsToCheckpoints() throws Exception { + given(shard1.getShardId()).willReturn("shard-01"); + given(shard2.getShardId()).willReturn("shard-02"); + given(shard3.getShardId()).willReturn("shard-03"); + 
given(kinesisClient.listShards("stream")).willReturn(asList(shard1, shard2, shard3)); - StartingPoint startingPoint = new StartingPoint(InitialPositionInStream.LATEST); - DynamicCheckpointGenerator underTest = new DynamicCheckpointGenerator("stream", - startingPoint); + StartingPoint startingPoint = new StartingPoint(InitialPositionInStream.LATEST); + DynamicCheckpointGenerator underTest = new DynamicCheckpointGenerator("stream", + startingPoint); - KinesisReaderCheckpoint checkpoint = underTest.generate(kinesisClient); + KinesisReaderCheckpoint checkpoint = underTest.generate(kinesisClient); - assertThat(checkpoint).hasSize(3); - } + assertThat(checkpoint).hasSize(3); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java index 567e25f8c982c..44ad67d4181d3 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisMockReadTest.java @@ -21,7 +21,9 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.google.common.collect.Iterables; + import java.util.List; + import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.DoFn; @@ -36,59 +38,60 @@ */ public class KinesisMockReadTest { - @Rule - public final transient TestPipeline p = TestPipeline.create(); - - @Test - public void readsDataFromMockKinesis() { - int noOfShards = 3; - int noOfEventsPerShard = 100; - List> testData = - provideTestData(noOfShards, noOfEventsPerShard); - - PCollection result = p - .apply( - KinesisIO.read() - .from("stream", InitialPositionInStream.TRIM_HORIZON) - .withClientProvider(new AmazonKinesisMock.Provider(testData, 10)) - .withMaxNumRecords(noOfShards * noOfEventsPerShard)) - .apply(ParDo.of(new KinesisRecordToTestData())); - PAssert.that(result).containsInAnyOrder(Iterables.concat(testData)); - p.run(); - } + @Rule + public final transient TestPipeline p = TestPipeline.create(); - private static class KinesisRecordToTestData extends - DoFn { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - c.output(new AmazonKinesisMock.TestData(c.element())); - } - } + @Test + public void readsDataFromMockKinesis() { + int noOfShards = 3; + int noOfEventsPerShard = 100; + List> testData = + provideTestData(noOfShards, noOfEventsPerShard); - private List> provideTestData( - int noOfShards, - int noOfEventsPerShard) { + PCollection result = p + .apply( + KinesisIO.read() + .from("stream", InitialPositionInStream.TRIM_HORIZON) + .withClientProvider(new AmazonKinesisMock.Provider(testData, 10)) + .withMaxNumRecords(noOfShards * noOfEventsPerShard)) + .apply(ParDo.of(new KinesisRecordToTestData())); + PAssert.that(result).containsInAnyOrder(Iterables.concat(testData)); + p.run(); + } - int seqNumber = 0; + private static class KinesisRecordToTestData extends + DoFn { - List> shardedData = newArrayList(); - for (int i = 0; i < noOfShards; ++i) { - List shardData = newArrayList(); - shardedData.add(shardData); + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + c.output(new AmazonKinesisMock.TestData(c.element())); + } + } + + private List> provideTestData( + int noOfShards, + int noOfEventsPerShard) { - DateTime arrival = DateTime.now(); - for (int j = 0; j < 
noOfEventsPerShard; ++j) { - arrival = arrival.plusSeconds(1); + int seqNumber = 0; - seqNumber++; - shardData.add(new AmazonKinesisMock.TestData( - Integer.toString(seqNumber), - arrival.toInstant(), - Integer.toString(seqNumber)) - ); - } - } + List> shardedData = newArrayList(); + for (int i = 0; i < noOfShards; ++i) { + List shardData = newArrayList(); + shardedData.add(shardData); - return shardedData; + DateTime arrival = DateTime.now(); + for (int j = 0; j < noOfEventsPerShard; ++j) { + arrival = arrival.plusSeconds(1); + + seqNumber++; + shardData.add(new AmazonKinesisMock.TestData( + Integer.toString(seqNumber), + arrival.toInstant(), + Integer.toString(seqNumber)) + ); + } } + + return shardedData; + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java index 8c8da641804f1..1038a47bccb0e 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderCheckpointTest.java @@ -17,13 +17,14 @@ */ package org.apache.beam.sdk.io.kinesis; - import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import com.google.common.collect.Iterables; + import java.util.Iterator; import java.util.List; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -35,33 +36,34 @@ */ @RunWith(MockitoJUnitRunner.class) public class KinesisReaderCheckpointTest { - @Mock - private ShardCheckpoint a, b, c; - private KinesisReaderCheckpoint checkpoint; + @Mock + private ShardCheckpoint a, b, c; + + private KinesisReaderCheckpoint checkpoint; - @Before - public void setUp() { - checkpoint = new KinesisReaderCheckpoint(asList(a, b, c)); - } + @Before + public void setUp() { + checkpoint = new KinesisReaderCheckpoint(asList(a, b, c)); + } - @Test - public void splitsCheckpointAccordingly() { - verifySplitInto(1); - verifySplitInto(2); - verifySplitInto(3); - verifySplitInto(4); - } + @Test + public void splitsCheckpointAccordingly() { + verifySplitInto(1); + verifySplitInto(2); + verifySplitInto(3); + verifySplitInto(4); + } - @Test(expected = UnsupportedOperationException.class) - public void isImmutable() { - Iterator iterator = checkpoint.iterator(); - iterator.remove(); - } + @Test(expected = UnsupportedOperationException.class) + public void isImmutable() { + Iterator iterator = checkpoint.iterator(); + iterator.remove(); + } - private void verifySplitInto(int size) { - List split = checkpoint.splitInto(size); - assertThat(Iterables.concat(split)).containsOnly(a, b, c); - assertThat(split).hasSize(Math.min(size, 3)); - } + private void verifySplitInto(int size) { + List split = checkpoint.splitInto(size); + assertThat(Iterables.concat(split)).containsOnly(a, b, c); + assertThat(split).hasSize(Math.min(size, 3)); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java index 8eb65465ecd34..5781033227db2 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderIT.java @@ -23,6 +23,7 @@ import static org.assertj.core.api.Assertions.assertThat; import com.amazonaws.regions.Regions; + import 
java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; @@ -31,6 +32,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; + import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; @@ -50,72 +52,75 @@ * You need to provide all {@link KinesisTestOptions} in order to run this. */ public class KinesisReaderIT { - private static final long PIPELINE_STARTUP_TIME = TimeUnit.SECONDS.toMillis(10); - private ExecutorService singleThreadExecutor = newSingleThreadExecutor(); - - @Rule - public final transient TestPipeline p = TestPipeline.create(); - - @Ignore - @Test - public void readsDataFromRealKinesisStream() - throws IOException, InterruptedException, ExecutionException { - KinesisTestOptions options = readKinesisOptions(); - List testData = prepareTestData(1000); - - Future future = startTestPipeline(testData, options); - KinesisUploader.uploadAll(testData, options); - future.get(); - } - private List prepareTestData(int count) { - List data = newArrayList(); - for (int i = 0; i < count; ++i) { - data.add(RandomStringUtils.randomAlphabetic(32)); - } - return data; - } + private static final long PIPELINE_STARTUP_TIME = TimeUnit.SECONDS.toMillis(10); + private ExecutorService singleThreadExecutor = newSingleThreadExecutor(); - private Future startTestPipeline(List testData, KinesisTestOptions options) - throws InterruptedException { - - PCollection result = p. - apply(KinesisIO.read() - .from(options.getAwsKinesisStream(), Instant.now()) - .withClientProvider(options.getAwsAccessKey(), options.getAwsSecretKey(), - Regions.fromName(options.getAwsKinesisRegion())) - .withMaxReadTime(Duration.standardMinutes(3)) - ). - apply(ParDo.of(new RecordDataToString())); - PAssert.that(result).containsInAnyOrder(testData); - - Future future = singleThreadExecutor.submit(new Callable() { - @Override - public Void call() throws Exception { - PipelineResult result = p.run(); - PipelineResult.State state = result.getState(); - while (state != PipelineResult.State.DONE && state != PipelineResult.State.FAILED) { - Thread.sleep(1000); - state = result.getState(); - } - assertThat(state).isEqualTo(PipelineResult.State.DONE); - return null; - } - }); - Thread.sleep(PIPELINE_STARTUP_TIME); - return future; - } + @Rule + public final transient TestPipeline p = TestPipeline.create(); + + @Ignore + @Test + public void readsDataFromRealKinesisStream() + throws IOException, InterruptedException, ExecutionException { + KinesisTestOptions options = readKinesisOptions(); + List testData = prepareTestData(1000); - private KinesisTestOptions readKinesisOptions() { - PipelineOptionsFactory.register(KinesisTestOptions.class); - return TestPipeline.testingPipelineOptions().as(KinesisTestOptions.class); + Future future = startTestPipeline(testData, options); + KinesisUploader.uploadAll(testData, options); + future.get(); + } + + private List prepareTestData(int count) { + List data = newArrayList(); + for (int i = 0; i < count; ++i) { + data.add(RandomStringUtils.randomAlphabetic(32)); } + return data; + } + + private Future startTestPipeline(List testData, KinesisTestOptions options) + throws InterruptedException { + + PCollection result = p. 
+ apply(KinesisIO.read() + .from(options.getAwsKinesisStream(), Instant.now()) + .withClientProvider(options.getAwsAccessKey(), options.getAwsSecretKey(), + Regions.fromName(options.getAwsKinesisRegion())) + .withMaxReadTime(Duration.standardMinutes(3)) + ). + apply(ParDo.of(new RecordDataToString())); + PAssert.that(result).containsInAnyOrder(testData); + + Future future = singleThreadExecutor.submit(new Callable() { - private static class RecordDataToString extends DoFn { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - checkNotNull(c.element(), "Null record given"); - c.output(new String(c.element().getData().array(), StandardCharsets.UTF_8)); + @Override + public Void call() throws Exception { + PipelineResult result = p.run(); + PipelineResult.State state = result.getState(); + while (state != PipelineResult.State.DONE && state != PipelineResult.State.FAILED) { + Thread.sleep(1000); + state = result.getState(); } + assertThat(state).isEqualTo(PipelineResult.State.DONE); + return null; + } + }); + Thread.sleep(PIPELINE_STARTUP_TIME); + return future; + } + + private KinesisTestOptions readKinesisOptions() { + PipelineOptionsFactory.register(KinesisTestOptions.class); + return TestPipeline.testingPipelineOptions().as(KinesisTestOptions.class); + } + + private static class RecordDataToString extends DoFn { + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + checkNotNull(c.element(), "Null record given"); + c.output(new String(c.element().getData().array(), StandardCharsets.UTF_8)); } + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java index 3111029d74e51..a26501ad12c5d 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.NoSuchElementException; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -34,87 +35,88 @@ */ @RunWith(MockitoJUnitRunner.class) public class KinesisReaderTest { - @Mock - private SimplifiedKinesisClient kinesis; - @Mock - private CheckpointGenerator generator; - @Mock - private ShardCheckpoint firstCheckpoint, secondCheckpoint; - @Mock - private ShardRecordsIterator firstIterator, secondIterator; - @Mock - private KinesisRecord a, b, c, d; - - private KinesisReader reader; - - @Before - public void setUp() throws IOException, TransientKinesisException { - when(generator.generate(kinesis)).thenReturn(new KinesisReaderCheckpoint( - asList(firstCheckpoint, secondCheckpoint) - )); - when(firstCheckpoint.getShardRecordsIterator(kinesis)).thenReturn(firstIterator); - when(secondCheckpoint.getShardRecordsIterator(kinesis)).thenReturn(secondIterator); - when(firstIterator.next()).thenReturn(CustomOptional.absent()); - when(secondIterator.next()).thenReturn(CustomOptional.absent()); - - reader = new KinesisReader(kinesis, generator, null); - } - - @Test - public void startReturnsFalseIfNoDataAtTheBeginning() throws IOException { - assertThat(reader.start()).isFalse(); - } - - @Test(expected = NoSuchElementException.class) - public void throwsNoSuchElementExceptionIfNoData() throws IOException { - reader.start(); - reader.getCurrent(); - } - - @Test - public void startReturnsTrueIfSomeDataAvailable() throws IOException, - 
TransientKinesisException { - when(firstIterator.next()). - thenReturn(CustomOptional.of(a)). - thenReturn(CustomOptional.absent()); - - assertThat(reader.start()).isTrue(); - } - - @Test - public void advanceReturnsFalseIfThereIsTransientExceptionInKinesis() - throws IOException, TransientKinesisException { - reader.start(); - - when(firstIterator.next()).thenThrow(TransientKinesisException.class); - - assertThat(reader.advance()).isFalse(); - } - - @Test - public void readsThroughAllDataAvailable() throws IOException, TransientKinesisException { - when(firstIterator.next()). - thenReturn(CustomOptional.absent()). - thenReturn(CustomOptional.of(a)). - thenReturn(CustomOptional.absent()). - thenReturn(CustomOptional.of(b)). - thenReturn(CustomOptional.absent()); - - when(secondIterator.next()). - thenReturn(CustomOptional.of(c)). - thenReturn(CustomOptional.absent()). - thenReturn(CustomOptional.of(d)). - thenReturn(CustomOptional.absent()); - - assertThat(reader.start()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(c); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(a); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(d); - assertThat(reader.advance()).isTrue(); - assertThat(reader.getCurrent()).isEqualTo(b); - assertThat(reader.advance()).isFalse(); - } + + @Mock + private SimplifiedKinesisClient kinesis; + @Mock + private CheckpointGenerator generator; + @Mock + private ShardCheckpoint firstCheckpoint, secondCheckpoint; + @Mock + private ShardRecordsIterator firstIterator, secondIterator; + @Mock + private KinesisRecord a, b, c, d; + + private KinesisReader reader; + + @Before + public void setUp() throws IOException, TransientKinesisException { + when(generator.generate(kinesis)).thenReturn(new KinesisReaderCheckpoint( + asList(firstCheckpoint, secondCheckpoint) + )); + when(firstCheckpoint.getShardRecordsIterator(kinesis)).thenReturn(firstIterator); + when(secondCheckpoint.getShardRecordsIterator(kinesis)).thenReturn(secondIterator); + when(firstIterator.next()).thenReturn(CustomOptional.absent()); + when(secondIterator.next()).thenReturn(CustomOptional.absent()); + + reader = new KinesisReader(kinesis, generator, null); + } + + @Test + public void startReturnsFalseIfNoDataAtTheBeginning() throws IOException { + assertThat(reader.start()).isFalse(); + } + + @Test(expected = NoSuchElementException.class) + public void throwsNoSuchElementExceptionIfNoData() throws IOException { + reader.start(); + reader.getCurrent(); + } + + @Test + public void startReturnsTrueIfSomeDataAvailable() throws IOException, + TransientKinesisException { + when(firstIterator.next()). + thenReturn(CustomOptional.of(a)). + thenReturn(CustomOptional.absent()); + + assertThat(reader.start()).isTrue(); + } + + @Test + public void advanceReturnsFalseIfThereIsTransientExceptionInKinesis() + throws IOException, TransientKinesisException { + reader.start(); + + when(firstIterator.next()).thenThrow(TransientKinesisException.class); + + assertThat(reader.advance()).isFalse(); + } + + @Test + public void readsThroughAllDataAvailable() throws IOException, TransientKinesisException { + when(firstIterator.next()). + thenReturn(CustomOptional.absent()). + thenReturn(CustomOptional.of(a)). + thenReturn(CustomOptional.absent()). + thenReturn(CustomOptional.of(b)). + thenReturn(CustomOptional.absent()); + + when(secondIterator.next()). + thenReturn(CustomOptional.of(c)). + thenReturn(CustomOptional.absent()). + thenReturn(CustomOptional.of(d)). 
+ thenReturn(CustomOptional.absent()); + + assertThat(reader.start()).isTrue(); + assertThat(reader.getCurrent()).isEqualTo(c); + assertThat(reader.advance()).isTrue(); + assertThat(reader.getCurrent()).isEqualTo(a); + assertThat(reader.advance()).isTrue(); + assertThat(reader.getCurrent()).isEqualTo(d); + assertThat(reader.advance()).isTrue(); + assertThat(reader.getCurrent()).isEqualTo(b); + assertThat(reader.advance()).isFalse(); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java index 8771c86c82a5a..c9f01bb11d3f3 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisRecordCoderTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kinesis; import java.nio.ByteBuffer; + import org.apache.beam.sdk.testing.CoderProperties; import org.joda.time.Instant; import org.junit.Test; @@ -26,20 +27,21 @@ * Tests {@link KinesisRecordCoder}. */ public class KinesisRecordCoderTest { - @Test - public void encodingAndDecodingWorks() throws Exception { - KinesisRecord record = new KinesisRecord( - ByteBuffer.wrap("data".getBytes()), - "sequence", - 128L, - "partition", - Instant.now(), - Instant.now(), - "stream", - "shard" - ); - CoderProperties.coderDecodeEncodeEqual( - new KinesisRecordCoder(), record - ); - } + + @Test + public void encodingAndDecodingWorks() throws Exception { + KinesisRecord record = new KinesisRecord( + ByteBuffer.wrap("data".getBytes()), + "sequence", + 128L, + "partition", + Instant.now(), + Instant.now(), + "stream", + "shard" + ); + CoderProperties.coderDecodeEncodeEqual( + new KinesisRecordCoder(), record + ); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java index 324de466776d9..76bcb273d5fa1 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisTestOptions.java @@ -25,23 +25,28 @@ * Options for Kinesis integration tests. 
*/ public interface KinesisTestOptions extends TestPipelineOptions { - @Description("AWS region where Kinesis stream resided") - @Default.String("aws-kinesis-region") - String getAwsKinesisRegion(); - void setAwsKinesisRegion(String value); - - @Description("Kinesis stream name") - @Default.String("aws-kinesis-stream") - String getAwsKinesisStream(); - void setAwsKinesisStream(String value); - - @Description("AWS secret key") - @Default.String("aws-secret-key") - String getAwsSecretKey(); - void setAwsSecretKey(String value); - - @Description("AWS access key") - @Default.String("aws-access-key") - String getAwsAccessKey(); - void setAwsAccessKey(String value); + + @Description("AWS region where Kinesis stream resided") + @Default.String("aws-kinesis-region") + String getAwsKinesisRegion(); + + void setAwsKinesisRegion(String value); + + @Description("Kinesis stream name") + @Default.String("aws-kinesis-stream") + String getAwsKinesisStream(); + + void setAwsKinesisStream(String value); + + @Description("AWS secret key") + @Default.String("aws-secret-key") + String getAwsSecretKey(); + + void setAwsSecretKey(String value); + + @Description("AWS access key") + @Default.String("aws-access-key") + String getAwsAccessKey(); + + void setAwsAccessKey(String value); } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisUploader.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisUploader.java index 7518ff71d82c2..7a7cb02202a51 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisUploader.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisUploader.java @@ -29,6 +29,7 @@ import com.amazonaws.services.kinesis.model.PutRecordsResultEntry; import com.google.common.base.Charsets; import com.google.common.collect.Lists; + import java.nio.ByteBuffer; import java.util.List; @@ -37,47 +38,46 @@ */ public class KinesisUploader { - public static final int MAX_NUMBER_OF_RECORDS_IN_BATCH = 499; - - public static void uploadAll(List data, KinesisTestOptions options) { - AmazonKinesisClient client = new AmazonKinesisClient( - new StaticCredentialsProvider( - new BasicAWSCredentials( - options.getAwsAccessKey(), options.getAwsSecretKey())) - ).withRegion(Regions.fromName(options.getAwsKinesisRegion())); + public static final int MAX_NUMBER_OF_RECORDS_IN_BATCH = 499; - List> partitions = Lists.partition(data, MAX_NUMBER_OF_RECORDS_IN_BATCH); + public static void uploadAll(List data, KinesisTestOptions options) { + AmazonKinesisClient client = new AmazonKinesisClient( + new StaticCredentialsProvider( + new BasicAWSCredentials( + options.getAwsAccessKey(), options.getAwsSecretKey())) + ).withRegion(Regions.fromName(options.getAwsKinesisRegion())); + List> partitions = Lists.partition(data, MAX_NUMBER_OF_RECORDS_IN_BATCH); - for (List partition : partitions) { - List allRecords = newArrayList(); - for (String row : partition) { - allRecords.add(new PutRecordsRequestEntry(). - withData(ByteBuffer.wrap(row.getBytes(Charsets.UTF_8))). - withPartitionKey(Integer.toString(row.hashCode())) + for (List partition : partitions) { + List allRecords = newArrayList(); + for (String row : partition) { + allRecords.add(new PutRecordsRequestEntry(). + withData(ByteBuffer.wrap(row.getBytes(Charsets.UTF_8))). + withPartitionKey(Integer.toString(row.hashCode())) - ); - } + ); + } - PutRecordsResult result; - do { - result = client.putRecords( - new PutRecordsRequest(). 
- withStreamName(options.getAwsKinesisStream()). - withRecords(allRecords)); - List failedRecords = newArrayList(); - int i = 0; - for (PutRecordsResultEntry row : result.getRecords()) { - if (row.getErrorCode() != null) { - failedRecords.add(allRecords.get(i)); - } - ++i; - } - allRecords = failedRecords; - } - - while (result.getFailedRecordCount() > 0); + PutRecordsResult result; + do { + result = client.putRecords( + new PutRecordsRequest(). + withStreamName(options.getAwsKinesisStream()). + withRecords(allRecords)); + List failedRecords = newArrayList(); + int i = 0; + for (PutRecordsResultEntry row : result.getRecords()) { + if (row.getErrorCode() != null) { + failedRecords.add(allRecords.get(i)); + } + ++i; } + allRecords = failedRecords; + } + + while (result.getFailedRecordCount() > 0); } + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java index f979c0108cdc2..cb325620abfc9 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RecordFilterTest.java @@ -20,47 +20,49 @@ import static org.mockito.BDDMockito.given; import com.google.common.collect.Lists; + import java.util.Collections; import java.util.List; + import org.assertj.core.api.Assertions; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; - /*** */ @RunWith(MockitoJUnitRunner.class) public class RecordFilterTest { - @Mock - private ShardCheckpoint checkpoint; - @Mock - private KinesisRecord record1, record2, record3, record4, record5; - @Test - public void shouldFilterOutRecordsBeforeOrAtCheckpoint() { - given(checkpoint.isBeforeOrAt(record1)).willReturn(false); - given(checkpoint.isBeforeOrAt(record2)).willReturn(true); - given(checkpoint.isBeforeOrAt(record3)).willReturn(true); - given(checkpoint.isBeforeOrAt(record4)).willReturn(false); - given(checkpoint.isBeforeOrAt(record5)).willReturn(true); - List records = Lists.newArrayList(record1, record2, - record3, record4, record5); - RecordFilter underTest = new RecordFilter(); + @Mock + private ShardCheckpoint checkpoint; + @Mock + private KinesisRecord record1, record2, record3, record4, record5; + + @Test + public void shouldFilterOutRecordsBeforeOrAtCheckpoint() { + given(checkpoint.isBeforeOrAt(record1)).willReturn(false); + given(checkpoint.isBeforeOrAt(record2)).willReturn(true); + given(checkpoint.isBeforeOrAt(record3)).willReturn(true); + given(checkpoint.isBeforeOrAt(record4)).willReturn(false); + given(checkpoint.isBeforeOrAt(record5)).willReturn(true); + List records = Lists.newArrayList(record1, record2, + record3, record4, record5); + RecordFilter underTest = new RecordFilter(); - List retainedRecords = underTest.apply(records, checkpoint); + List retainedRecords = underTest.apply(records, checkpoint); - Assertions.assertThat(retainedRecords).containsOnly(record2, record3, record5); - } + Assertions.assertThat(retainedRecords).containsOnly(record2, record3, record5); + } - @Test - public void shouldNotFailOnEmptyList() { - List records = Collections.emptyList(); - RecordFilter underTest = new RecordFilter(); + @Test + public void shouldNotFailOnEmptyList() { + List records = Collections.emptyList(); + RecordFilter underTest = new RecordFilter(); - List retainedRecords = underTest.apply(records, checkpoint); + List 
retainedRecords = underTest.apply(records, checkpoint); - Assertions.assertThat(retainedRecords).isEmpty(); - } + Assertions.assertThat(retainedRecords).isEmpty(); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RoundRobinTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RoundRobinTest.java index f032eeab377d4..e4abce47d6117 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RoundRobinTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RoundRobinTest.java @@ -22,36 +22,38 @@ import java.util.Collections; import java.util.List; + import org.junit.Test; /** * Tests {@link RoundRobin}. */ public class RoundRobinTest { - @Test(expected = IllegalArgumentException.class) - public void doesNotAllowCreationWithEmptyCollection() { - new RoundRobin<>(Collections.emptyList()); - } - @Test - public void goesThroughElementsInCycle() { - List input = newArrayList("a", "b", "c"); + @Test(expected = IllegalArgumentException.class) + public void doesNotAllowCreationWithEmptyCollection() { + new RoundRobin<>(Collections.emptyList()); + } - RoundRobin roundRobin = new RoundRobin<>(newArrayList(input)); + @Test + public void goesThroughElementsInCycle() { + List input = newArrayList("a", "b", "c"); - input.addAll(input); // duplicate the input - for (String element : input) { - assertThat(roundRobin.getCurrent()).isEqualTo(element); - assertThat(roundRobin.getCurrent()).isEqualTo(element); - roundRobin.moveForward(); - } + RoundRobin roundRobin = new RoundRobin<>(newArrayList(input)); + + input.addAll(input); // duplicate the input + for (String element : input) { + assertThat(roundRobin.getCurrent()).isEqualTo(element); + assertThat(roundRobin.getCurrent()).isEqualTo(element); + roundRobin.moveForward(); } + } - @Test - public void usualIteratorGoesThroughElementsOnce() { - List input = newArrayList("a", "b", "c"); + @Test + public void usualIteratorGoesThroughElementsOnce() { + List input = newArrayList("a", "b", "c"); - RoundRobin roundRobin = new RoundRobin<>(input); - assertThat(roundRobin).hasSize(3).containsOnly(input.toArray(new String[0])); - } + RoundRobin roundRobin = new RoundRobin<>(input); + assertThat(roundRobin).hasSize(3).containsOnly(input.toArray(new String[0])); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java index 39ab36f9255a5..d4784c48f11ec 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardCheckpointTest.java @@ -32,7 +32,9 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ExtendedSequenceNumber; import com.amazonaws.services.kinesis.model.ShardIteratorType; + import java.io.IOException; + import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Before; @@ -46,104 +48,105 @@ */ @RunWith(MockitoJUnitRunner.class) public class ShardCheckpointTest { - private static final String AT_SEQUENCE_SHARD_IT = "AT_SEQUENCE_SHARD_IT"; - private static final String AFTER_SEQUENCE_SHARD_IT = "AFTER_SEQUENCE_SHARD_IT"; - private static final String STREAM_NAME = "STREAM"; - private static final String SHARD_ID = "SHARD_ID"; - @Mock - private SimplifiedKinesisClient client; - - @Before - public void setUp() throws IOException, TransientKinesisException { - 
when(client.getShardIterator( - eq(STREAM_NAME), eq(SHARD_ID), eq(AT_SEQUENCE_NUMBER), - anyString(), isNull(Instant.class))). - thenReturn(AT_SEQUENCE_SHARD_IT); - when(client.getShardIterator( - eq(STREAM_NAME), eq(SHARD_ID), eq(AFTER_SEQUENCE_NUMBER), - anyString(), isNull(Instant.class))). - thenReturn(AFTER_SEQUENCE_SHARD_IT); - } - - @Test - public void testProvidingShardIterator() throws IOException, TransientKinesisException { - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) - .isEqualTo(AT_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) - .isEqualTo(AFTER_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)).isEqualTo - (AT_SEQUENCE_SHARD_IT); - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)) - .isEqualTo(AT_SEQUENCE_SHARD_IT); - } - - @Test - public void testComparisonWithExtendedSequenceNumber() { - assertThat(new ShardCheckpoint("", "", new StartingPoint(LATEST)).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isTrue(); - - assertThat(new ShardCheckpoint("", "", new StartingPoint(TRIM_HORIZON)).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isTrue(); - - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "10", 1L).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isTrue(); - - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isTrue(); - - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isFalse(); - - assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 1L).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("100", 0L)) - )).isFalse(); - - assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( - recordWith(new ExtendedSequenceNumber("99", 1L)) - )).isFalse(); - } - - @Test - public void testComparisonWithTimestamp() { - DateTime referenceTimestamp = DateTime.now(); - - assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.minusMillis(10).toInstant())) - ).isFalse(); - - assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.toInstant())) - ).isTrue(); - - assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) - .isBeforeOrAt(recordWith(referenceTimestamp.plusMillis(10).toInstant())) - ).isTrue(); - } - - private KinesisRecord recordWith(ExtendedSequenceNumber extendedSequenceNumber) { - KinesisRecord record = mock(KinesisRecord.class); - given(record.getExtendedSequenceNumber()).willReturn(extendedSequenceNumber); - return record; - } - - private ShardCheckpoint checkpoint(ShardIteratorType iteratorType, String sequenceNumber, - Long subSequenceNumber) { - return new ShardCheckpoint(STREAM_NAME, SHARD_ID, iteratorType, sequenceNumber, - subSequenceNumber); - } - - private KinesisRecord recordWith(Instant approximateArrivalTimestamp) { - KinesisRecord record = mock(KinesisRecord.class); - given(record.getApproximateArrivalTimestamp()).willReturn(approximateArrivalTimestamp); - return record; - } - - private ShardCheckpoint checkpoint(ShardIteratorType iteratorType, Instant timestamp) { - return new ShardCheckpoint(STREAM_NAME, SHARD_ID, iteratorType, timestamp); - } + + private static final String AT_SEQUENCE_SHARD_IT = "AT_SEQUENCE_SHARD_IT"; + private 
static final String AFTER_SEQUENCE_SHARD_IT = "AFTER_SEQUENCE_SHARD_IT"; + private static final String STREAM_NAME = "STREAM"; + private static final String SHARD_ID = "SHARD_ID"; + @Mock + private SimplifiedKinesisClient client; + + @Before + public void setUp() throws IOException, TransientKinesisException { + when(client.getShardIterator( + eq(STREAM_NAME), eq(SHARD_ID), eq(AT_SEQUENCE_NUMBER), + anyString(), isNull(Instant.class))). + thenReturn(AT_SEQUENCE_SHARD_IT); + when(client.getShardIterator( + eq(STREAM_NAME), eq(SHARD_ID), eq(AFTER_SEQUENCE_NUMBER), + anyString(), isNull(Instant.class))). + thenReturn(AFTER_SEQUENCE_SHARD_IT); + } + + @Test + public void testProvidingShardIterator() throws IOException, TransientKinesisException { + assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) + .isEqualTo(AT_SEQUENCE_SHARD_IT); + assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", null).getShardIterator(client)) + .isEqualTo(AFTER_SEQUENCE_SHARD_IT); + assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)).isEqualTo + (AT_SEQUENCE_SHARD_IT); + assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 10L).getShardIterator(client)) + .isEqualTo(AT_SEQUENCE_SHARD_IT); + } + + @Test + public void testComparisonWithExtendedSequenceNumber() { + assertThat(new ShardCheckpoint("", "", new StartingPoint(LATEST)).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isTrue(); + + assertThat(new ShardCheckpoint("", "", new StartingPoint(TRIM_HORIZON)).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isTrue(); + + assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "10", 1L).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isTrue(); + + assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isTrue(); + + assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isFalse(); + + assertThat(checkpoint(AT_SEQUENCE_NUMBER, "100", 1L).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("100", 0L)) + )).isFalse(); + + assertThat(checkpoint(AFTER_SEQUENCE_NUMBER, "100", 0L).isBeforeOrAt( + recordWith(new ExtendedSequenceNumber("99", 1L)) + )).isFalse(); + } + + @Test + public void testComparisonWithTimestamp() { + DateTime referenceTimestamp = DateTime.now(); + + assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) + .isBeforeOrAt(recordWith(referenceTimestamp.minusMillis(10).toInstant())) + ).isFalse(); + + assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) + .isBeforeOrAt(recordWith(referenceTimestamp.toInstant())) + ).isTrue(); + + assertThat(checkpoint(AT_TIMESTAMP, referenceTimestamp.toInstant()) + .isBeforeOrAt(recordWith(referenceTimestamp.plusMillis(10).toInstant())) + ).isTrue(); + } + + private KinesisRecord recordWith(ExtendedSequenceNumber extendedSequenceNumber) { + KinesisRecord record = mock(KinesisRecord.class); + given(record.getExtendedSequenceNumber()).willReturn(extendedSequenceNumber); + return record; + } + + private ShardCheckpoint checkpoint(ShardIteratorType iteratorType, String sequenceNumber, + Long subSequenceNumber) { + return new ShardCheckpoint(STREAM_NAME, SHARD_ID, iteratorType, sequenceNumber, + subSequenceNumber); + } + + private KinesisRecord recordWith(Instant approximateArrivalTimestamp) { + KinesisRecord record = mock(KinesisRecord.class); + 
given(record.getApproximateArrivalTimestamp()).willReturn(approximateArrivalTimestamp); + return record; + } + + private ShardCheckpoint checkpoint(ShardIteratorType iteratorType, Instant timestamp) { + return new ShardCheckpoint(STREAM_NAME, SHARD_ID, iteratorType, timestamp); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java index 49e806dc12e85..4b2190fe9e463 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardRecordsIteratorTest.java @@ -25,8 +25,10 @@ import static org.mockito.Mockito.when; import com.amazonaws.services.kinesis.model.ExpiredIteratorException; + import java.io.IOException; import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -40,112 +42,114 @@ */ @RunWith(MockitoJUnitRunner.class) public class ShardRecordsIteratorTest { - private static final String INITIAL_ITERATOR = "INITIAL_ITERATOR"; - private static final String SECOND_ITERATOR = "SECOND_ITERATOR"; - private static final String SECOND_REFRESHED_ITERATOR = "SECOND_REFRESHED_ITERATOR"; - private static final String THIRD_ITERATOR = "THIRD_ITERATOR"; - private static final String STREAM_NAME = "STREAM_NAME"; - private static final String SHARD_ID = "SHARD_ID"; - - @Mock - private SimplifiedKinesisClient kinesisClient; - @Mock - private ShardCheckpoint firstCheckpoint, aCheckpoint, bCheckpoint, cCheckpoint, dCheckpoint; - @Mock - private GetKinesisRecordsResult firstResult, secondResult, thirdResult; - @Mock - private KinesisRecord a, b, c, d; - @Mock - private RecordFilter recordFilter; - - private ShardRecordsIterator iterator; - - @Before - public void setUp() throws IOException, TransientKinesisException { - when(firstCheckpoint.getShardIterator(kinesisClient)).thenReturn(INITIAL_ITERATOR); - when(firstCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(firstCheckpoint.getShardId()).thenReturn(SHARD_ID); - - when(firstCheckpoint.moveAfter(a)).thenReturn(aCheckpoint); - when(aCheckpoint.moveAfter(b)).thenReturn(bCheckpoint); - when(aCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(aCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(bCheckpoint.moveAfter(c)).thenReturn(cCheckpoint); - when(bCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(bCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(cCheckpoint.moveAfter(d)).thenReturn(dCheckpoint); - when(cCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(cCheckpoint.getShardId()).thenReturn(SHARD_ID); - when(dCheckpoint.getStreamName()).thenReturn(STREAM_NAME); - when(dCheckpoint.getShardId()).thenReturn(SHARD_ID); - - when(kinesisClient.getRecords(INITIAL_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenReturn(firstResult); - when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenReturn(secondResult); - when(kinesisClient.getRecords(THIRD_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenReturn(thirdResult); - - when(firstResult.getNextShardIterator()).thenReturn(SECOND_ITERATOR); - when(secondResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); - when(thirdResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); - - when(firstResult.getRecords()).thenReturn(Collections.emptyList()); - when(secondResult.getRecords()).thenReturn(Collections.emptyList()); - 
when(thirdResult.getRecords()).thenReturn(Collections.emptyList()); - - when(recordFilter.apply(anyListOf(KinesisRecord.class), any(ShardCheckpoint - .class))).thenAnswer(new IdentityAnswer()); - - iterator = new ShardRecordsIterator(firstCheckpoint, kinesisClient, recordFilter); - } - - @Test - public void returnsAbsentIfNoRecordsPresent() throws IOException, TransientKinesisException { - assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); - assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); - assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); - } - - @Test - public void goesThroughAvailableRecords() throws IOException, TransientKinesisException { - when(firstResult.getRecords()).thenReturn(asList(a, b, c)); - when(secondResult.getRecords()).thenReturn(singletonList(d)); - - assertThat(iterator.getCheckpoint()).isEqualTo(firstCheckpoint); - assertThat(iterator.next()).isEqualTo(CustomOptional.of(a)); - assertThat(iterator.getCheckpoint()).isEqualTo(aCheckpoint); - assertThat(iterator.next()).isEqualTo(CustomOptional.of(b)); - assertThat(iterator.getCheckpoint()).isEqualTo(bCheckpoint); - assertThat(iterator.next()).isEqualTo(CustomOptional.of(c)); - assertThat(iterator.getCheckpoint()).isEqualTo(cCheckpoint); - assertThat(iterator.next()).isEqualTo(CustomOptional.of(d)); - assertThat(iterator.getCheckpoint()).isEqualTo(dCheckpoint); - assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); - assertThat(iterator.getCheckpoint()).isEqualTo(dCheckpoint); - } - - @Test - public void refreshesExpiredIterator() throws IOException, TransientKinesisException { - when(firstResult.getRecords()).thenReturn(singletonList(a)); - when(secondResult.getRecords()).thenReturn(singletonList(b)); - - when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenThrow(ExpiredIteratorException.class); - when(aCheckpoint.getShardIterator(kinesisClient)) - .thenReturn(SECOND_REFRESHED_ITERATOR); - when(kinesisClient.getRecords(SECOND_REFRESHED_ITERATOR, STREAM_NAME, SHARD_ID)) - .thenReturn(secondResult); - - assertThat(iterator.next()).isEqualTo(CustomOptional.of(a)); - assertThat(iterator.next()).isEqualTo(CustomOptional.of(b)); - assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); - } - private static class IdentityAnswer implements Answer { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - return invocation.getArguments()[0]; - } + private static final String INITIAL_ITERATOR = "INITIAL_ITERATOR"; + private static final String SECOND_ITERATOR = "SECOND_ITERATOR"; + private static final String SECOND_REFRESHED_ITERATOR = "SECOND_REFRESHED_ITERATOR"; + private static final String THIRD_ITERATOR = "THIRD_ITERATOR"; + private static final String STREAM_NAME = "STREAM_NAME"; + private static final String SHARD_ID = "SHARD_ID"; + + @Mock + private SimplifiedKinesisClient kinesisClient; + @Mock + private ShardCheckpoint firstCheckpoint, aCheckpoint, bCheckpoint, cCheckpoint, dCheckpoint; + @Mock + private GetKinesisRecordsResult firstResult, secondResult, thirdResult; + @Mock + private KinesisRecord a, b, c, d; + @Mock + private RecordFilter recordFilter; + + private ShardRecordsIterator iterator; + + @Before + public void setUp() throws IOException, TransientKinesisException { + when(firstCheckpoint.getShardIterator(kinesisClient)).thenReturn(INITIAL_ITERATOR); + when(firstCheckpoint.getStreamName()).thenReturn(STREAM_NAME); + when(firstCheckpoint.getShardId()).thenReturn(SHARD_ID); + + 
when(firstCheckpoint.moveAfter(a)).thenReturn(aCheckpoint); + when(aCheckpoint.moveAfter(b)).thenReturn(bCheckpoint); + when(aCheckpoint.getStreamName()).thenReturn(STREAM_NAME); + when(aCheckpoint.getShardId()).thenReturn(SHARD_ID); + when(bCheckpoint.moveAfter(c)).thenReturn(cCheckpoint); + when(bCheckpoint.getStreamName()).thenReturn(STREAM_NAME); + when(bCheckpoint.getShardId()).thenReturn(SHARD_ID); + when(cCheckpoint.moveAfter(d)).thenReturn(dCheckpoint); + when(cCheckpoint.getStreamName()).thenReturn(STREAM_NAME); + when(cCheckpoint.getShardId()).thenReturn(SHARD_ID); + when(dCheckpoint.getStreamName()).thenReturn(STREAM_NAME); + when(dCheckpoint.getShardId()).thenReturn(SHARD_ID); + + when(kinesisClient.getRecords(INITIAL_ITERATOR, STREAM_NAME, SHARD_ID)) + .thenReturn(firstResult); + when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)) + .thenReturn(secondResult); + when(kinesisClient.getRecords(THIRD_ITERATOR, STREAM_NAME, SHARD_ID)) + .thenReturn(thirdResult); + + when(firstResult.getNextShardIterator()).thenReturn(SECOND_ITERATOR); + when(secondResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); + when(thirdResult.getNextShardIterator()).thenReturn(THIRD_ITERATOR); + + when(firstResult.getRecords()).thenReturn(Collections.emptyList()); + when(secondResult.getRecords()).thenReturn(Collections.emptyList()); + when(thirdResult.getRecords()).thenReturn(Collections.emptyList()); + + when(recordFilter.apply(anyListOf(KinesisRecord.class), any(ShardCheckpoint + .class))).thenAnswer(new IdentityAnswer()); + + iterator = new ShardRecordsIterator(firstCheckpoint, kinesisClient, recordFilter); + } + + @Test + public void returnsAbsentIfNoRecordsPresent() throws IOException, TransientKinesisException { + assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); + assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); + assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); + } + + @Test + public void goesThroughAvailableRecords() throws IOException, TransientKinesisException { + when(firstResult.getRecords()).thenReturn(asList(a, b, c)); + when(secondResult.getRecords()).thenReturn(singletonList(d)); + + assertThat(iterator.getCheckpoint()).isEqualTo(firstCheckpoint); + assertThat(iterator.next()).isEqualTo(CustomOptional.of(a)); + assertThat(iterator.getCheckpoint()).isEqualTo(aCheckpoint); + assertThat(iterator.next()).isEqualTo(CustomOptional.of(b)); + assertThat(iterator.getCheckpoint()).isEqualTo(bCheckpoint); + assertThat(iterator.next()).isEqualTo(CustomOptional.of(c)); + assertThat(iterator.getCheckpoint()).isEqualTo(cCheckpoint); + assertThat(iterator.next()).isEqualTo(CustomOptional.of(d)); + assertThat(iterator.getCheckpoint()).isEqualTo(dCheckpoint); + assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); + assertThat(iterator.getCheckpoint()).isEqualTo(dCheckpoint); + } + + @Test + public void refreshesExpiredIterator() throws IOException, TransientKinesisException { + when(firstResult.getRecords()).thenReturn(singletonList(a)); + when(secondResult.getRecords()).thenReturn(singletonList(b)); + + when(kinesisClient.getRecords(SECOND_ITERATOR, STREAM_NAME, SHARD_ID)) + .thenThrow(ExpiredIteratorException.class); + when(aCheckpoint.getShardIterator(kinesisClient)) + .thenReturn(SECOND_REFRESHED_ITERATOR); + when(kinesisClient.getRecords(SECOND_REFRESHED_ITERATOR, STREAM_NAME, SHARD_ID)) + .thenReturn(secondResult); + + assertThat(iterator.next()).isEqualTo(CustomOptional.of(a)); + 
assertThat(iterator.next()).isEqualTo(CustomOptional.of(b)); + assertThat(iterator.next()).isEqualTo(CustomOptional.absent()); + } + + private static class IdentityAnswer implements Answer { + + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + return invocation.getArguments()[0]; } + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java index 96434fd4c838a..2f8757c5c0806 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java @@ -34,7 +34,9 @@ import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.ShardIteratorType; import com.amazonaws.services.kinesis.model.StreamDescription; + import java.util.List; + import org.joda.time.Instant; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,179 +48,180 @@ */ @RunWith(MockitoJUnitRunner.class) public class SimplifiedKinesisClientTest { - private static final String STREAM = "stream"; - private static final String SHARD_1 = "shard-01"; - private static final String SHARD_2 = "shard-02"; - private static final String SHARD_3 = "shard-03"; - private static final String SHARD_ITERATOR = "iterator"; - private static final String SEQUENCE_NUMBER = "abc123"; - - @Mock - private AmazonKinesis kinesis; - @InjectMocks - private SimplifiedKinesisClient underTest; - - @Test - public void shouldReturnIteratorStartingWithSequenceNumber() throws Exception { - given(kinesis.getShardIterator(new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) - .withStartingSequenceNumber(SEQUENCE_NUMBER) - )).willReturn(new GetShardIteratorResult() - .withShardIterator(SHARD_ITERATOR)); - - String stream = underTest.getShardIterator(STREAM, SHARD_1, - ShardIteratorType.AT_SEQUENCE_NUMBER, SEQUENCE_NUMBER, null); - - assertThat(stream).isEqualTo(SHARD_ITERATOR); - } - - @Test - public void shouldReturnIteratorStartingWithTimestamp() throws Exception { - Instant timestamp = Instant.now(); - given(kinesis.getShardIterator(new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) - .withTimestamp(timestamp.toDate()) - )).willReturn(new GetShardIteratorResult() - .withShardIterator(SHARD_ITERATOR)); - - String stream = underTest.getShardIterator(STREAM, SHARD_1, - ShardIteratorType.AT_SEQUENCE_NUMBER, null, timestamp); - - assertThat(stream).isEqualTo(SHARD_ITERATOR); - } - - @Test - public void shouldHandleExpiredIterationExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError(new ExpiredIteratorException(""), - ExpiredIteratorException.class); - } - - @Test - public void shouldHandleLimitExceededExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError(new LimitExceededException(""), - TransientKinesisException.class); - } - - @Test - public void shouldHandleProvisionedThroughputExceededExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError(new ProvisionedThroughputExceededException(""), - TransientKinesisException.class); - } - - @Test - public void shouldHandleServiceErrorForGetShardIterator() { - 
shouldHandleGetShardIteratorError(newAmazonServiceException(ErrorType.Service), - TransientKinesisException.class); - } - - @Test - public void shouldHandleClientErrorForGetShardIterator() { - shouldHandleGetShardIteratorError(newAmazonServiceException(ErrorType.Client), - RuntimeException.class); - } - - @Test - public void shouldHandleUnexpectedExceptionForGetShardIterator() { - shouldHandleGetShardIteratorError(new NullPointerException(), - RuntimeException.class); - } - - private void shouldHandleGetShardIteratorError( - Exception thrownException, - Class expectedExceptionClass) { - GetShardIteratorRequest request = new GetShardIteratorRequest() - .withStreamName(STREAM) - .withShardId(SHARD_1) - .withShardIteratorType(ShardIteratorType.LATEST); - - given(kinesis.getShardIterator(request)).willThrow(thrownException); - - try { - underTest.getShardIterator(STREAM, SHARD_1, ShardIteratorType.LATEST, null, null); - failBecauseExceptionWasNotThrown(expectedExceptionClass); - } catch (Exception e) { - assertThat(e).isExactlyInstanceOf(expectedExceptionClass); - } finally { - reset(kinesis); - } - } - - @Test - public void shouldListAllShards() throws Exception { - Shard shard1 = new Shard().withShardId(SHARD_1); - Shard shard2 = new Shard().withShardId(SHARD_2); - Shard shard3 = new Shard().withShardId(SHARD_3); - given(kinesis.describeStream(STREAM, null)).willReturn(new DescribeStreamResult() - .withStreamDescription(new StreamDescription() - .withShards(shard1, shard2) - .withHasMoreShards(true))); - given(kinesis.describeStream(STREAM, SHARD_2)).willReturn(new DescribeStreamResult() - .withStreamDescription(new StreamDescription() - .withShards(shard3) - .withHasMoreShards(false))); - - List shards = underTest.listShards(STREAM); - - assertThat(shards).containsOnly(shard1, shard2, shard3); - } - - @Test - public void shouldHandleExpiredIterationExceptionForShardListing() { - shouldHandleShardListingError(new ExpiredIteratorException(""), - ExpiredIteratorException.class); - } - - @Test - public void shouldHandleLimitExceededExceptionForShardListing() { - shouldHandleShardListingError(new LimitExceededException(""), - TransientKinesisException.class); - } - - @Test - public void shouldHandleProvisionedThroughputExceededExceptionForShardListing() { - shouldHandleShardListingError(new ProvisionedThroughputExceededException(""), - TransientKinesisException.class); - } - @Test - public void shouldHandleServiceErrorForShardListing() { - shouldHandleShardListingError(newAmazonServiceException(ErrorType.Service), - TransientKinesisException.class); - } - - @Test - public void shouldHandleClientErrorForShardListing() { - shouldHandleShardListingError(newAmazonServiceException(ErrorType.Client), - RuntimeException.class); - } - - @Test - public void shouldHandleUnexpectedExceptionForShardListing() { - shouldHandleShardListingError(new NullPointerException(), - RuntimeException.class); - } - - private void shouldHandleShardListingError( - Exception thrownException, - Class expectedExceptionClass) { - given(kinesis.describeStream(STREAM, null)).willThrow(thrownException); - try { - underTest.listShards(STREAM); - failBecauseExceptionWasNotThrown(expectedExceptionClass); - } catch (Exception e) { - assertThat(e).isExactlyInstanceOf(expectedExceptionClass); - } finally { - reset(kinesis); - } - } - - private AmazonServiceException newAmazonServiceException(ErrorType errorType) { - AmazonServiceException exception = new AmazonServiceException(""); - exception.setErrorType(errorType); - return 
exception; - } + private static final String STREAM = "stream"; + private static final String SHARD_1 = "shard-01"; + private static final String SHARD_2 = "shard-02"; + private static final String SHARD_3 = "shard-03"; + private static final String SHARD_ITERATOR = "iterator"; + private static final String SEQUENCE_NUMBER = "abc123"; + + @Mock + private AmazonKinesis kinesis; + @InjectMocks + private SimplifiedKinesisClient underTest; + + @Test + public void shouldReturnIteratorStartingWithSequenceNumber() throws Exception { + given(kinesis.getShardIterator(new GetShardIteratorRequest() + .withStreamName(STREAM) + .withShardId(SHARD_1) + .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) + .withStartingSequenceNumber(SEQUENCE_NUMBER) + )).willReturn(new GetShardIteratorResult() + .withShardIterator(SHARD_ITERATOR)); + + String stream = underTest.getShardIterator(STREAM, SHARD_1, + ShardIteratorType.AT_SEQUENCE_NUMBER, SEQUENCE_NUMBER, null); + + assertThat(stream).isEqualTo(SHARD_ITERATOR); + } + + @Test + public void shouldReturnIteratorStartingWithTimestamp() throws Exception { + Instant timestamp = Instant.now(); + given(kinesis.getShardIterator(new GetShardIteratorRequest() + .withStreamName(STREAM) + .withShardId(SHARD_1) + .withShardIteratorType(ShardIteratorType.AT_SEQUENCE_NUMBER) + .withTimestamp(timestamp.toDate()) + )).willReturn(new GetShardIteratorResult() + .withShardIterator(SHARD_ITERATOR)); + + String stream = underTest.getShardIterator(STREAM, SHARD_1, + ShardIteratorType.AT_SEQUENCE_NUMBER, null, timestamp); + + assertThat(stream).isEqualTo(SHARD_ITERATOR); + } + + @Test + public void shouldHandleExpiredIterationExceptionForGetShardIterator() { + shouldHandleGetShardIteratorError(new ExpiredIteratorException(""), + ExpiredIteratorException.class); + } + + @Test + public void shouldHandleLimitExceededExceptionForGetShardIterator() { + shouldHandleGetShardIteratorError(new LimitExceededException(""), + TransientKinesisException.class); + } + + @Test + public void shouldHandleProvisionedThroughputExceededExceptionForGetShardIterator() { + shouldHandleGetShardIteratorError(new ProvisionedThroughputExceededException(""), + TransientKinesisException.class); + } + + @Test + public void shouldHandleServiceErrorForGetShardIterator() { + shouldHandleGetShardIteratorError(newAmazonServiceException(ErrorType.Service), + TransientKinesisException.class); + } + + @Test + public void shouldHandleClientErrorForGetShardIterator() { + shouldHandleGetShardIteratorError(newAmazonServiceException(ErrorType.Client), + RuntimeException.class); + } + + @Test + public void shouldHandleUnexpectedExceptionForGetShardIterator() { + shouldHandleGetShardIteratorError(new NullPointerException(), + RuntimeException.class); + } + + private void shouldHandleGetShardIteratorError( + Exception thrownException, + Class expectedExceptionClass) { + GetShardIteratorRequest request = new GetShardIteratorRequest() + .withStreamName(STREAM) + .withShardId(SHARD_1) + .withShardIteratorType(ShardIteratorType.LATEST); + + given(kinesis.getShardIterator(request)).willThrow(thrownException); + + try { + underTest.getShardIterator(STREAM, SHARD_1, ShardIteratorType.LATEST, null, null); + failBecauseExceptionWasNotThrown(expectedExceptionClass); + } catch (Exception e) { + assertThat(e).isExactlyInstanceOf(expectedExceptionClass); + } finally { + reset(kinesis); + } + } + + @Test + public void shouldListAllShards() throws Exception { + Shard shard1 = new Shard().withShardId(SHARD_1); + Shard shard2 = new 
Shard().withShardId(SHARD_2); + Shard shard3 = new Shard().withShardId(SHARD_3); + given(kinesis.describeStream(STREAM, null)).willReturn(new DescribeStreamResult() + .withStreamDescription(new StreamDescription() + .withShards(shard1, shard2) + .withHasMoreShards(true))); + given(kinesis.describeStream(STREAM, SHARD_2)).willReturn(new DescribeStreamResult() + .withStreamDescription(new StreamDescription() + .withShards(shard3) + .withHasMoreShards(false))); + + List shards = underTest.listShards(STREAM); + + assertThat(shards).containsOnly(shard1, shard2, shard3); + } + + @Test + public void shouldHandleExpiredIterationExceptionForShardListing() { + shouldHandleShardListingError(new ExpiredIteratorException(""), + ExpiredIteratorException.class); + } + + @Test + public void shouldHandleLimitExceededExceptionForShardListing() { + shouldHandleShardListingError(new LimitExceededException(""), + TransientKinesisException.class); + } + + @Test + public void shouldHandleProvisionedThroughputExceededExceptionForShardListing() { + shouldHandleShardListingError(new ProvisionedThroughputExceededException(""), + TransientKinesisException.class); + } + + @Test + public void shouldHandleServiceErrorForShardListing() { + shouldHandleShardListingError(newAmazonServiceException(ErrorType.Service), + TransientKinesisException.class); + } + + @Test + public void shouldHandleClientErrorForShardListing() { + shouldHandleShardListingError(newAmazonServiceException(ErrorType.Client), + RuntimeException.class); + } + + @Test + public void shouldHandleUnexpectedExceptionForShardListing() { + shouldHandleShardListingError(new NullPointerException(), + RuntimeException.class); + } + + private void shouldHandleShardListingError( + Exception thrownException, + Class expectedExceptionClass) { + given(kinesis.describeStream(STREAM, null)).willThrow(thrownException); + try { + underTest.listShards(STREAM); + failBecauseExceptionWasNotThrown(expectedExceptionClass); + } catch (Exception e) { + assertThat(e).isExactlyInstanceOf(expectedExceptionClass); + } finally { + reset(kinesis); + } + } + + private AmazonServiceException newAmazonServiceException(ErrorType errorType) { + AmazonServiceException exception = new AmazonServiceException(""); + exception.setErrorType(errorType); + return exception; + } } From 435cbcfae67c3a2bc8a72437fbdf7350fe5ac10a Mon Sep 17 00:00:00 2001 From: Colin Phipps Date: Mon, 26 Jun 2017 13:34:19 +0000 Subject: [PATCH 192/200] Add client-side throttling. The approach used is as described in https://landing.google.com/sre/book/chapters/handling-overload.html#client-side-throttling-a7sYUg . By backing off individual workers in response to high error rates, we relieve pressure on the Datastore service, increasing the chance that the workload can complete successfully. The exported cumulativeThrottledSeconds could also be used as an autoscaling signal in future. 
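For readers who do not want to chase the SRE-book link, the core of the approach is a single rejection-probability formula, which the patch implements in AdaptiveThrottler.throttlingProbability(). The sketch below restates that formula with plain counters so the behaviour is easy to follow; it is illustrative only, the class and method names (ThrottleSketch, rejectionProbability) are placeholders, and the real AdaptiveThrottler keeps the two counts in time-bucketed MovingFunctions so that old history ages out of the window.

    import java.util.Random;

    /**
     * Minimal illustrative sketch of client-side adaptive throttling:
     *   p(reject) = max(0, (requests - K * accepts) / (requests + 1))
     * The AdaptiveThrottler added in this patch uses the same formula but
     * tracks requests and accepts in time-bucketed MovingFunctions; plain
     * long counters are used here only for brevity.
     */
    class ThrottleSketch {
      private final double overloadRatio; // "K": how far requests may exceed accepts.
      private final Random random = new Random();
      private long requests; // All attempts, including ones rejected locally.
      private long accepts;  // Attempts the backend handled successfully.

      ThrottleSketch(double overloadRatio) {
        this.overloadRatio = overloadRatio;
      }

      // Probability that the next request should be dropped client-side.
      double rejectionProbability() {
        return Math.max(0, (requests - overloadRatio * accepts) / (requests + 1));
      }

      // Returns true if the caller should skip (or retry later) this request.
      boolean throttleRequest() {
        double p = rejectionProbability();
        requests++; // A locally rejected attempt still counts as a failed request.
        return random.nextDouble() < p;
      }

      // Call after the backend accepted a request.
      void successfulRequest() {
        accepts++;
      }
    }

With an overloadRatio above 1, as the class javadoc below recommends, a nonzero fraction of requests is always allowed through even during sustained failures, which is what lets a backed-off worker notice that the Datastore service has recovered and ramp its traffic back up.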
--- .../io/gcp/datastore/AdaptiveThrottler.java | 103 ++++++++++++++++ .../sdk/io/gcp/datastore/DatastoreV1.java | 25 +++- .../gcp/datastore/AdaptiveThrottlerTest.java | 111 ++++++++++++++++++ 3 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottler.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottlerTest.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottler.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottler.java new file mode 100644 index 0000000000000..ce6ebe63e5a03 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottler.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.sdk.io.gcp.datastore; + +import com.google.common.annotations.VisibleForTesting; +import java.util.Random; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.util.MovingFunction; + + +/** + * An implementation of client-side adaptive throttling. See + * https://landing.google.com/sre/book/chapters/handling-overload.html#client-side-throttling-a7sYUg + * for a full discussion of the use case and algorithm applied. + */ +class AdaptiveThrottler { + private final MovingFunction successfulRequests; + private final MovingFunction allRequests; + + /** The target ratio between requests sent and successful requests. This is "K" in the formula in + * https://landing.google.com/sre/book/chapters/handling-overload.html */ + private final double overloadRatio; + + /** The target minimum number of requests per samplePeriodMs, even if no requests succeed. Must be + * greater than 0, else we could throttle to zero. Because every decision is probabilistic, there + * is no guarantee that the request rate in any given interval will not be zero. (This is the +1 + * from the formula in https://landing.google.com/sre/book/chapters/handling-overload.html */ + private static final double MIN_REQUESTS = 1; + private final Random random; + + /** + * @param samplePeriodMs the time window to keep of request history to inform throttling + * decisions. + * @param sampleUpdateMs the length of buckets within this time window. + * @param overloadRatio the target ratio between requests sent and successful requests. You should + * always set this to more than 1, otherwise the client would never try to send more requests than + * succeeded in the past - so it could never recover from temporary setbacks. 
+ */ + public AdaptiveThrottler(long samplePeriodMs, long sampleUpdateMs, + double overloadRatio) { + this(samplePeriodMs, sampleUpdateMs, overloadRatio, new Random()); + } + + @VisibleForTesting + AdaptiveThrottler(long samplePeriodMs, long sampleUpdateMs, + double overloadRatio, Random random) { + allRequests = + new MovingFunction(samplePeriodMs, sampleUpdateMs, + 1 /* numSignificantBuckets */, 1 /* numSignificantSamples */, Sum.ofLongs()); + successfulRequests = + new MovingFunction(samplePeriodMs, sampleUpdateMs, + 1 /* numSignificantBuckets */, 1 /* numSignificantSamples */, Sum.ofLongs()); + this.overloadRatio = overloadRatio; + this.random = random; + } + + @VisibleForTesting + double throttlingProbability(long nowMsSinceEpoch) { + if (!allRequests.isSignificant()) { + return 0; + } + long allRequestsNow = allRequests.get(nowMsSinceEpoch); + long successfulRequestsNow = successfulRequests.get(nowMsSinceEpoch); + return Math.max(0, + (allRequestsNow - overloadRatio * successfulRequestsNow) / (allRequestsNow + MIN_REQUESTS)); + } + + /** + * Call this before sending a request to the remote service; if this returns true, drop the + * request (treating it as a failure or trying it again at a later time). + */ + public boolean throttleRequest(long nowMsSinceEpoch) { + double delayProbability = throttlingProbability(nowMsSinceEpoch); + // Note that we increment the count of all requests here, even if we return true - so even if we + // tell the client not to send a request at all, it still counts as a failed request. + allRequests.add(nowMsSinceEpoch, 1); + + return (random.nextDouble() < delayProbability); + } + + /** + * Call this after {@link throttleRequest} if your request was successful. + */ + public void successfulRequest(long nowMsSinceEpoch) { + successfulRequests.add(nowMsSinceEpoch, 1); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java index e67f4b2fcd972..5f65428141af3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java @@ -71,6 +71,8 @@ import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -1209,6 +1211,13 @@ static class DatastoreWriterFn extends DoFn { private final List mutations = new ArrayList<>(); private int mutationsSize = 0; // Accumulated size of protos in mutations. 
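As a minimal sketch of the call pattern that the following hunks wire into DatastoreWriterFn.flushBatch() -- not part of the patch: the wrapper class, method and constant names here are hypothetical, the class must sit in the same package as the package-private AdaptiveThrottler to compile, and the bounded retries, metrics counters and FluentBackoff of the real code are omitted:

    package org.apache.beam.sdk.io.gcp.datastore;

    import java.util.concurrent.Callable;

    /** Sketch: consult an AdaptiveThrottler around a retried RPC. */
    class ThrottledCallExample {

      // Stand-in for the target batch latency used as the local back-off sleep.
      private static final long TARGET_LATENCY_MS = 5000;

      /** Retries {@code rpc}, backing off locally whenever the throttler rejects. */
      static <T> T callWithThrottling(AdaptiveThrottler throttler, Callable<T> rpc)
          throws InterruptedException {
        while (true) {
          long now = System.currentTimeMillis();
          if (throttler.throttleRequest(now)) {
            // Locally rejected: sleep instead of adding load to the overloaded backend.
            Thread.sleep(TARGET_LATENCY_MS);
            continue;
          }
          try {
            T result = rpc.call();
            throttler.successfulRequest(now); // Feeds the "accepts" moving sum.
            return result;
          } catch (Exception e) {
            // Failure: nothing is reported, so the throttling probability rises.
            // The real flushBatch() below also bounds retries and rethrows.
          }
        }
      }
    }

DatastoreWriterFn itself constructs the throttler lazily in startBundle(), as new AdaptiveThrottler(120000, 10000, 1.25), because the class is not serializable.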
private WriteBatcher writeBatcher; + private transient AdaptiveThrottler throttler; + private final Counter throttledSeconds = + Metrics.counter(DatastoreWriterFn.class, "cumulativeThrottlingSeconds"); + private final Counter rpcErrors = + Metrics.counter(DatastoreWriterFn.class, "datastoreRpcErrors"); + private final Counter rpcSuccesses = + Metrics.counter(DatastoreWriterFn.class, "datastoreRpcSuccesses"); private static final int MAX_RETRIES = 5; private static final FluentBackoff BUNDLE_WRITE_BACKOFF = @@ -1237,6 +1246,10 @@ static class DatastoreWriterFn extends DoFn { public void startBundle(StartBundleContext c) { datastore = datastoreFactory.getDatastore(c.getPipelineOptions(), projectId.get(), localhost); writeBatcher.start(); + if (throttler == null) { + // Initialize throttler at first use, because it is not serializable. + throttler = new AdaptiveThrottler(120000, 10000, 1.25); + } } @ProcessElement @@ -1284,11 +1297,20 @@ private void flushBatch() throws DatastoreException, IOException, InterruptedExc commitRequest.setMode(CommitRequest.Mode.NON_TRANSACTIONAL); long startTime = System.currentTimeMillis(), endTime; + if (throttler.throttleRequest(startTime)) { + LOG.info("Delaying request due to previous failures"); + throttledSeconds.inc(WriteBatcherImpl.DATASTORE_BATCH_TARGET_LATENCY_MS / 1000); + sleeper.sleep(WriteBatcherImpl.DATASTORE_BATCH_TARGET_LATENCY_MS); + continue; + } + try { datastore.commit(commitRequest.build()); endTime = System.currentTimeMillis(); writeBatcher.addRequestLatency(endTime, endTime - startTime, mutations.size()); + throttler.successfulRequest(startTime); + rpcSuccesses.inc(); // Break if the commit threw no exception. break; @@ -1300,11 +1322,12 @@ private void flushBatch() throws DatastoreException, IOException, InterruptedExc endTime = System.currentTimeMillis(); writeBatcher.addRequestLatency(endTime, endTime - startTime, mutations.size()); } - // Only log the code and message for potentially-transient errors. The entire exception // will be propagated upon the last retry. LOG.error("Error writing batch of {} mutations to Datastore ({}): {}", mutations.size(), exception.getCode(), exception.getMessage()); + rpcErrors.inc(); + if (!BackOffUtils.next(sleeper, backoff)) { LOG.error("Aborting after {} retries.", MAX_RETRIES); throw exception; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottlerTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottlerTest.java new file mode 100644 index 0000000000000..c12cf55447887 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/AdaptiveThrottlerTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.datastore; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; + +import java.util.Random; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** + * Tests for {@link AdaptiveThrottler}. + */ +@RunWith(JUnit4.class) +public class AdaptiveThrottlerTest { + + static final long START_TIME_MS = 0; + static final long SAMPLE_PERIOD_MS = 60000; + static final long SAMPLE_BUCKET_MS = 1000; + static final double OVERLOAD_RATIO = 2; + + /** Returns a throttler configured with the standard parameters above. */ + AdaptiveThrottler getThrottler() { + return new AdaptiveThrottler(SAMPLE_PERIOD_MS, SAMPLE_BUCKET_MS, OVERLOAD_RATIO); + } + + @Test + public void testNoInitialThrottling() throws Exception { + AdaptiveThrottler throttler = getThrottler(); + assertThat(throttler.throttlingProbability(START_TIME_MS), equalTo(0.0)); + assertThat("first request is not throttled", + throttler.throttleRequest(START_TIME_MS), equalTo(false)); + } + + @Test + public void testNoThrottlingIfNoErrors() throws Exception { + AdaptiveThrottler throttler = getThrottler(); + long t = START_TIME_MS; + for (; t < START_TIME_MS + 20; t++) { + assertFalse(throttler.throttleRequest(t)); + throttler.successfulRequest(t); + } + assertThat(throttler.throttlingProbability(t), equalTo(0.0)); + } + + @Test + public void testNoThrottlingAfterErrorsExpire() throws Exception { + AdaptiveThrottler throttler = getThrottler(); + long t = START_TIME_MS; + for (; t < START_TIME_MS + SAMPLE_PERIOD_MS; t++) { + throttler.throttleRequest(t); + // and no successfulRequest. + } + assertThat("check that we set up a non-zero probability of throttling", + throttler.throttlingProbability(t), greaterThan(0.0)); + for (; t < START_TIME_MS + 2 * SAMPLE_PERIOD_MS; t++) { + throttler.throttleRequest(t); + throttler.successfulRequest(t); + } + assertThat(throttler.throttlingProbability(t), equalTo(0.0)); + } + + @Test + public void testThrottlingAfterErrors() throws Exception { + Random mockRandom = Mockito.mock(Random.class); + Mockito.when(mockRandom.nextDouble()).thenReturn( + 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, + 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9); + AdaptiveThrottler throttler = new AdaptiveThrottler( + SAMPLE_PERIOD_MS, SAMPLE_BUCKET_MS, OVERLOAD_RATIO, mockRandom); + for (int i = 0; i < 20; i++) { + boolean throttled = throttler.throttleRequest(START_TIME_MS + i); + // 1/3rd of requests succeeding. + if (i % 3 == 1) { + throttler.successfulRequest(START_TIME_MS + i); + } + + // Once we have some history in place, check what throttling happens. + if (i >= 10) { + // Expect 1/3rd of requests to be throttled. (So 1/3rd throttled, 1/3rd succeeding, 1/3rd + // tried and failing). + assertThat(String.format("for i=%d", i), + throttler.throttlingProbability(START_TIME_MS + i), closeTo(0.33, /*error=*/ 0.1)); + // Requests 10..13 should be throttled, 14..19 not throttled given the mocked random numbers + // that we fed to throttler. 
+ assertThat(String.format("for i=%d", i), throttled, equalTo(i < 14)); + } + } + } +} From aaffe15ee891b5656e448eac4bd3a7ff72eee315 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 11 Jul 2017 10:09:12 -0700 Subject: [PATCH 193/200] Remove dead (and wrong) viewFromProto overload --- .../core/construction/ParDoTranslation.java | 21 ------------------- .../construction/ParDoTranslationTest.java | 2 +- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 90c9aadfdfb86..03f29ff319b88 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -41,7 +41,6 @@ import java.util.Map; import java.util.Set; import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; -import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; @@ -509,26 +508,6 @@ private static SideInput toProto(PCollectionView view) { return builder.build(); } - public static PCollectionView viewFromProto( - Pipeline pipeline, - SideInput sideInput, - String localName, - RunnerApi.PTransform parDoTransform, - Components components) - throws IOException { - - String pCollectionId = parDoTransform.getInputsOrThrow(localName); - - // This may be a PCollection defined in another language, but we should be - // able to rehydrate it enough to stick it in a side input. The coder may not - // be grokkable in Java. - PCollection pCollection = - PCollectionTranslation.fromProto( - pipeline, components.getPcollectionsOrThrow(pCollectionId), components); - - return viewFromProto(sideInput, localName, pCollection, parDoTransform, components); - } - /** * Create a {@link PCollectionView} from a side input spec and an already-deserialized {@link * PCollection} that should be wired up. 
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index 6fdf9d6ad8b73..a87a16d213617 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -162,9 +162,9 @@ public void toAndFromTransformProto() throws Exception { SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); PCollectionView restoredView = ParDoTranslation.viewFromProto( - rehydratedPipeline, sideInput, view.getTagInternal().getId(), + view.getPCollection(), protoTransform, protoComponents); assertThat(restoredView.getTagInternal(), equalTo(view.getTagInternal())); From cd216f796bebf78101dce7ab6387f3db9b839fc7 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Fri, 23 Jun 2017 18:02:10 -0700 Subject: [PATCH 194/200] Adds TextIO.readAll(), implemented rather naively --- ...edSplittableProcessElementInvokerTest.java | 2 +- .../core/SplittableParDoProcessFnTest.java | 2 +- .../DataflowPipelineTranslatorTest.java | 2 +- .../apache/beam/sdk/io/CompressedSource.java | 40 ++- .../apache/beam/sdk/io/OffsetBasedSource.java | 22 +- .../java/org/apache/beam/sdk/io/TextIO.java | 230 ++++++++++++++++-- .../range}/OffsetRange.java | 32 ++- .../beam/sdk/io/range/OffsetRangeTracker.java | 3 + .../splittabledofn/OffsetRangeTracker.java | 1 + .../org/apache/beam/sdk/io/TextIOTest.java | 62 +++-- .../sdk/transforms/SplittableDoFnTest.java | 2 +- .../OffsetRangeTrackerTest.java | 1 + 12 files changed, 314 insertions(+), 85 deletions(-) rename sdks/java/core/src/main/java/org/apache/beam/sdk/{transforms/splittabledofn => io/range}/OffsetRange.java (61%) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java index a2f6acc9cdeb3..b80a6326d906d 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java @@ -25,10 +25,10 @@ import java.util.Collection; import java.util.concurrent.Executors; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java index 9543de8c61a00..1cd127547cf12 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java @@ -39,11 +39,11 @@ import 
org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; -import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index 948af1cf606ac..43b27880ee37a 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -84,6 +84,7 @@ import org.apache.beam.sdk.extensions.gcp.storage.GcsPathValidator; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; @@ -98,7 +99,6 @@ import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java index 6ab8dec3db32b..4baac367f6869 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java @@ -96,12 +96,6 @@ private interface FileNameBasedDecompressingChannelFactory */ ReadableByteChannel createDecompressingChannel(String fileName, ReadableByteChannel channel) throws IOException; - - /** - * Given a file name, returns true if the file name matches any supported compression - * scheme. - */ - boolean isCompressed(String fileName); } /** @@ -242,6 +236,16 @@ public int read(byte[] b, int off, int len) throws IOException { @Override public abstract ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel) throws IOException; + + /** Returns whether the file's extension matches of one of the known compression formats. 
*/ + public static boolean isCompressed(String filename) { + for (CompressionMode type : CompressionMode.values()) { + if (type.matches(filename)) { + return true; + } + } + return false; + } } /** @@ -273,16 +277,6 @@ public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channe ReadableByteChannel.class.getSimpleName(), ReadableByteChannel.class.getSimpleName())); } - - @Override - public boolean isCompressed(String fileName) { - for (CompressionMode type : CompressionMode.values()) { - if (type.matches(fileName)) { - return true; - } - } - return false; - } } private final FileBasedSource sourceDelegate; @@ -366,13 +360,9 @@ protected FileBasedSource createForSubrangeOfFile(Metadata metadata, long sta */ @Override protected final boolean isSplittable() throws Exception { - if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) { - FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory = - (FileNameBasedDecompressingChannelFactory) channelFactory; - return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec()) - && sourceDelegate.isSplittable(); - } - return false; + return channelFactory instanceof FileNameBasedDecompressingChannelFactory + && !CompressionMode.isCompressed(getFileOrPatternSpec()) + && sourceDelegate.isSplittable(); } /** @@ -386,9 +376,7 @@ protected final boolean isSplittable() throws Exception { @Override protected final FileBasedReader createSingleFileReader(PipelineOptions options) { if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) { - FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory = - (FileNameBasedDecompressingChannelFactory) channelFactory; - if (!fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec())) { + if (!CompressionMode.isCompressed(getFileOrPatternSpec())) { return sourceDelegate.createSingleFileReader(options); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java index 05f0d97d86fde..c3687a95232c0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/OffsetBasedSource.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; import java.util.NoSuchElementException; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.io.range.OffsetRangeTracker; import org.apache.beam.sdk.io.range.RangeTracker; import org.apache.beam.sdk.options.PipelineOptions; @@ -110,8 +111,7 @@ public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { @Override public List> split( long desiredBundleSizeBytes, PipelineOptions options) throws Exception { - // Split the range into bundles based on the desiredBundleSizeBytes. Final bundle is adjusted to - // make sure that we do not end up with a too small bundle at the end. If the desired bundle + // Split the range into bundles based on the desiredBundleSizeBytes. If the desired bundle // size is smaller than the minBundleSize of the source then minBundleSize will be used instead. 
long desiredBundleSizeOffsetUnits = Math.max( @@ -119,20 +119,10 @@ public List> split( minBundleSize); List> subSources = new ArrayList<>(); - long start = startOffset; - long maxEnd = Math.min(endOffset, getMaxEndOffset(options)); - - while (start < maxEnd) { - long end = start + desiredBundleSizeOffsetUnits; - end = Math.min(end, maxEnd); - // Avoid having a too small bundle at the end and ensure that we respect minBundleSize. - long remaining = maxEnd - end; - if ((remaining < desiredBundleSizeOffsetUnits / 4) || (remaining < minBundleSize)) { - end = maxEnd; - } - subSources.add(createSourceForSubrange(start, end)); - - start = end; + for (OffsetRange range : + new OffsetRange(startOffset, Math.min(endOffset, getMaxEndOffset(options))) + .split(desiredBundleSizeOffsetUnits, minBundleSize)) { + subSources.add(createSourceForSubrange(range.getFrom(), range.getTo())); } return subSources; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java index 524158968237a..78340f32a16a9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java @@ -23,25 +23,37 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.CompressedSource.CompressionMode; import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params; import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.FileBasedSink.WritableByteChannelFactory; import org.apache.beam.sdk.io.Read.Bounded; +import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.io.fs.MatchResult.Status; import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SerializableFunctions; +import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; @@ -51,13 +63,14 @@ * *

    To read a {@link PCollection} from one or more text files, use {@code TextIO.read()} to * instantiate a transform and use {@link TextIO.Read#from(String)} to specify the path of the - * file(s) to be read. + * file(s) to be read. Alternatively, if the filenames to be read are themselves in a + * {@link PCollection}, apply {@link TextIO#readAll()}. * *

    {@link TextIO.Read} returns a {@link PCollection} of {@link String Strings}, each * corresponding to one line of an input UTF-8 text file (split into lines delimited by '\n', '\r', * or '\r\n'). * - *

    Example: + *

    Example 1: reading a file or filepattern. * *

    {@code
      * Pipeline p = ...;
    @@ -66,6 +79,19 @@
 * PCollection<String> lines = p.apply(TextIO.read().from("/local/path/to/file.txt"));
      * }
    * + *

    Example 2: reading a PCollection of filenames. + * + *

    {@code
    + * Pipeline p = ...;
    + *
    + * // E.g. the filenames might be computed from other data in the pipeline, or
    + * // read from a data source.
+ * PCollection<String> filenames = ...;
    + *
    + * // Read all files in the collection.
+ * PCollection<String> lines = filenames.apply(TextIO.readAll());
    + * }
    + * *

    To write a {@link PCollection} to one or more text files, use {@code TextIO.write()}, using * {@link TextIO.Write#to(String)} to specify the output prefix of the files to write. * @@ -131,6 +157,26 @@ public static Read read() { return new AutoValue_TextIO_Read.Builder().setCompressionType(CompressionType.AUTO).build(); } + /** + * A {@link PTransform} that works like {@link #read}, but reads each file in a {@link + * PCollection} of filepatterns. + * + *

    Can be applied to both bounded and unbounded {@link PCollection PCollections}, so this is + * suitable for reading a {@link PCollection} of filepatterns arriving as a stream. However, every + * filepattern is expanded once at the moment it is processed, rather than watched for new files + * matching the filepattern to appear. Likewise, every file is read once, rather than watched for + * new entries. + */ + public static ReadAll readAll() { + return new AutoValue_TextIO_ReadAll.Builder() + .setCompressionType(CompressionType.AUTO) + // 64MB is a reasonable value that allows to amortize the cost of opening files, + // but is not so large as to exhaust a typical runner's maximum amount of output per + // ProcessElement call. + .setDesiredBundleSizeBytes(64 * 1024 * 1024L) + .build(); + } + /** * A {@link PTransform} that writes a {@link PCollection} to a text file (or multiple text files * matching a sharding pattern), with each element of the input collection encoded into its own @@ -228,29 +274,34 @@ public PCollection expand(PBegin input) { // Helper to create a source specific to the requested compression type. protected FileBasedSource getSource() { - switch (getCompressionType()) { + return wrapWithCompression(new TextSource(getFilepattern()), getCompressionType()); + } + + private static FileBasedSource wrapWithCompression( + FileBasedSource source, CompressionType compressionType) { + switch (compressionType) { case UNCOMPRESSED: - return new TextSource(getFilepattern()); + return source; case AUTO: - return CompressedSource.from(new TextSource(getFilepattern())); + return CompressedSource.from(source); case BZIP2: return - CompressedSource.from(new TextSource(getFilepattern())) - .withDecompression(CompressedSource.CompressionMode.BZIP2); + CompressedSource.from(source) + .withDecompression(CompressionMode.BZIP2); case GZIP: return - CompressedSource.from(new TextSource(getFilepattern())) - .withDecompression(CompressedSource.CompressionMode.GZIP); + CompressedSource.from(source) + .withDecompression(CompressionMode.GZIP); case ZIP: return - CompressedSource.from(new TextSource(getFilepattern())) - .withDecompression(CompressedSource.CompressionMode.ZIP); + CompressedSource.from(source) + .withDecompression(CompressionMode.ZIP); case DEFLATE: return - CompressedSource.from(new TextSource(getFilepattern())) - .withDecompression(CompressedSource.CompressionMode.DEFLATE); + CompressedSource.from(source) + .withDecompression(CompressionMode.DEFLATE); default: - throw new IllegalArgumentException("Unknown compression type: " + getFilepattern()); + throw new IllegalArgumentException("Unknown compression type: " + compressionType); } } @@ -273,7 +324,156 @@ protected Coder getDefaultOutputCoder() { } } - // /////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////// + + /** Implementation of {@link #readAll}. */ + @AutoValue + public abstract static class ReadAll + extends PTransform, PCollection> { + abstract CompressionType getCompressionType(); + abstract long getDesiredBundleSizeBytes(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setCompressionType(CompressionType compressionType); + abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes); + + abstract ReadAll build(); + } + + /** Same as {@link Read#withCompressionType(CompressionType)}. 
*/ + public ReadAll withCompressionType(CompressionType compressionType) { + return toBuilder().setCompressionType(compressionType).build(); + } + + @VisibleForTesting + ReadAll withDesiredBundleSizeBytes(long desiredBundleSizeBytes) { + return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build(); + } + + @Override + public PCollection expand(PCollection input) { + return input + .apply("Expand glob", ParDo.of(new ExpandGlobFn())) + .apply( + "Split into ranges", + ParDo.of(new SplitIntoRangesFn(getCompressionType(), getDesiredBundleSizeBytes()))) + .apply("Reshuffle", new ReshuffleWithUniqueKey>()) + .apply("Read", ParDo.of(new ReadTextFn(this))); + } + + private static class ReshuffleWithUniqueKey + extends PTransform, PCollection> { + @Override + public PCollection expand(PCollection input) { + return input + .apply("Unique key", ParDo.of(new AssignUniqueKeyFn())) + .apply("Reshuffle", Reshuffle.of()) + .apply("Values", Values.create()); + } + } + + private static class AssignUniqueKeyFn extends DoFn> { + private int index; + + @Setup + public void setup() { + this.index = ThreadLocalRandom.current().nextInt(); + } + + @ProcessElement + public void process(ProcessContext c) { + c.output(KV.of(++index, c.element())); + } + } + + private static class ExpandGlobFn extends DoFn { + @ProcessElement + public void process(ProcessContext c) throws Exception { + MatchResult match = FileSystems.match(c.element()); + checkArgument( + match.status().equals(Status.OK), + "Failed to match filepattern %s: %s", + c.element(), + match.status()); + for (Metadata metadata : match.metadata()) { + c.output(metadata); + } + } + } + + private static class SplitIntoRangesFn extends DoFn> { + private final CompressionType compressionType; + private final long desiredBundleSize; + + private SplitIntoRangesFn(CompressionType compressionType, long desiredBundleSize) { + this.compressionType = compressionType; + this.desiredBundleSize = desiredBundleSize; + } + + @ProcessElement + public void process(ProcessContext c) { + Metadata metadata = c.element(); + final boolean isSplittable = isSplittable(metadata, compressionType); + if (!isSplittable) { + c.output(KV.of(metadata, new OffsetRange(0, metadata.sizeBytes()))); + return; + } + for (OffsetRange range : + new OffsetRange(0, metadata.sizeBytes()).split(desiredBundleSize, 0)) { + c.output(KV.of(metadata, range)); + } + } + + static boolean isSplittable(Metadata metadata, CompressionType compressionType) { + if (!metadata.isReadSeekEfficient()) { + return false; + } + switch (compressionType) { + case AUTO: + return !CompressionMode.isCompressed(metadata.resourceId().toString()); + case UNCOMPRESSED: + return true; + case GZIP: + case BZIP2: + case ZIP: + case DEFLATE: + return false; + default: + throw new UnsupportedOperationException("Unknown compression type: " + compressionType); + } + } + } + + private static class ReadTextFn extends DoFn, String> { + private final TextIO.ReadAll spec; + + private ReadTextFn(ReadAll spec) { + this.spec = spec; + } + + @ProcessElement + public void process(ProcessContext c) throws IOException { + Metadata metadata = c.element().getKey(); + OffsetRange range = c.element().getValue(); + FileBasedSource source = + TextIO.Read.wrapWithCompression( + new TextSource(StaticValueProvider.of(metadata.toString())), + spec.getCompressionType()); + BoundedSource.BoundedReader reader = + source + .createForSubrangeOfFile(metadata, range.getFrom(), range.getTo()) + .createReader(c.getPipelineOptions()); + for (boolean more 
= reader.start(); more; more = reader.advance()) { + c.output(reader.getCurrent()); + } + } + } + } + + ///////////////////////////////////////////////////////////////////////////// /** Implementation of {@link #write}. */ @AutoValue diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRange.java similarity index 61% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRange.java index 104f5f2564a06..d3bff3739bb50 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRange.java @@ -15,15 +15,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.transforms.splittabledofn; +package org.apache.beam.sdk.io.range; import static com.google.common.base.Preconditions.checkArgument; import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; /** A restriction represented by a range of integers [from, to). */ public class OffsetRange - implements Serializable, HasDefaultTracker { + implements Serializable, + HasDefaultTracker< + OffsetRange, org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker> { private final long from; private final long to; @@ -42,8 +47,8 @@ public long getTo() { } @Override - public OffsetRangeTracker newTracker() { - return new OffsetRangeTracker(this); + public org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker newTracker() { + return new org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker(this); } @Override @@ -74,4 +79,23 @@ public int hashCode() { result = 31 * result + (int) (to ^ (to >>> 32)); return result; } + + public List split(long desiredNumOffsetsPerSplit, long minNumOffsetPerSplit) { + List res = new ArrayList<>(); + long start = getFrom(); + long maxEnd = getTo(); + + while (start < maxEnd) { + long end = start + desiredNumOffsetsPerSplit; + end = Math.min(end, maxEnd); + // Avoid having a too small range at the end and ensure that we respect minNumOffsetPerSplit. + long remaining = maxEnd - end; + if ((remaining < desiredNumOffsetsPerSplit / 4) || (remaining < minNumOffsetPerSplit)) { + end = maxEnd; + } + res.add(new OffsetRange(start, end)); + start = end; + } + return res; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java index 51e2b1ac2a1ce..8f0083e71483f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/range/OffsetRangeTracker.java @@ -26,6 +26,9 @@ /** * A {@link RangeTracker} for non-negative positions of type {@code long}. + * + *

    Not to be confused with {@link + * org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker}. */ public class OffsetRangeTracker implements RangeTracker { private static final Logger LOG = LoggerFactory.getLogger(OffsetRangeTracker.class); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java index 0271a0d1f3f1e..62c10a71ffa21 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.transforms.DoFn; /** diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java index 8797ff76c7942..a6be4fb4876fb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java @@ -120,10 +120,10 @@ public class TextIOTest { private static final String MY_HEADER = "myHeader"; private static final String MY_FOOTER = "myFooter"; - private static final String[] EMPTY = new String[] {}; - private static final String[] TINY = - new String[] {"Irritable eagle", "Optimistic jay", "Fanciful hawk"}; - private static final String[] LARGE = makeLines(1000); + private static final List EMPTY = Collections.emptyList(); + private static final List TINY = + Arrays.asList("Irritable eagle", "Optimistic jay", "Fanciful hawk"); + private static final List LARGE = makeLines(1000); private static Path tempFolder; private static File emptyTxt; @@ -148,7 +148,7 @@ public class TextIOTest { @Rule public ExpectedException expectedException = ExpectedException.none(); - private static File writeToFile(String[] lines, String filename, CompressionType compression) + private static File writeToFile(List lines, String filename, CompressionType compression) throws IOException { File file = tempFolder.resolve(filename).toFile(); OutputStream output = new FileOutputStream(file); @@ -791,7 +791,7 @@ public void testCompressionTypeIsSet() throws Exception { * Helper that writes the given lines (adding a newline in between) to a stream, then closes the * stream. */ - private static void writeToStreamAndClose(String[] lines, OutputStream outputStream) { + private static void writeToStreamAndClose(List lines, OutputStream outputStream) { try (PrintStream writer = new PrintStream(outputStream)) { for (String line : lines) { writer.println(line); @@ -800,27 +800,33 @@ private static void writeToStreamAndClose(String[] lines, OutputStream outputStr } /** - * Helper method that runs TextIO.read().from(filename).withCompressionType(compressionType) + * Helper method that runs TextIO.read().from(filename).withCompressionType(compressionType) and + * TextIO.readAll().withCompressionType(compressionType) applied to the single filename, * and asserts that the results match the given expected output. 
*/ private void assertReadingCompressedFileMatchesExpected( - File file, CompressionType compressionType, String[] expected) { - - TextIO.Read read = - TextIO.read().from(file.getPath()).withCompressionType(compressionType); - PCollection output = p.apply("Read_" + file + "_" + compressionType.toString(), read); - - PAssert.that(output).containsInAnyOrder(expected); + File file, CompressionType compressionType, List expected) { + + TextIO.Read read = TextIO.read().from(file.getPath()).withCompressionType(compressionType); + PAssert.that(p.apply("Read_" + file + "_" + compressionType.toString(), read)) + .containsInAnyOrder(expected); + + TextIO.ReadAll readAll = + TextIO.readAll().withCompressionType(compressionType).withDesiredBundleSizeBytes(10); + PAssert.that( + p.apply("Create_" + file, Create.of(file.getPath())) + .apply("Read_" + compressionType.toString(), readAll)) + .containsInAnyOrder(expected); p.run(); } /** * Helper to make an array of compressible strings. Returns ["word"i] for i in range(0,n). */ - private static String[] makeLines(int n) { - String[] ret = new String[n]; + private static List makeLines(int n) { + List ret = new ArrayList<>(); for (int i = 0; i < n; ++i) { - ret[i] = "word" + i; + ret.add("word" + i); } return ret; } @@ -1004,7 +1010,7 @@ public void testZipCompressedReadWithMultiEntriesFile() throws Exception { String filename = createZipFile(expected, "multiple entries", entry0, entry1, entry2); assertReadingCompressedFileMatchesExpected( - new File(filename), CompressionType.ZIP, expected.toArray(new String[]{})); + new File(filename), CompressionType.ZIP, expected); } /** @@ -1023,7 +1029,7 @@ public void testZipCompressedReadWithComplexEmptyAndPresentEntries() throws Exce new String[]{"dog"}); assertReadingCompressedFileMatchesExpected( - new File(filename), CompressionType.ZIP, new String[] {"cat", "dog"}); + new File(filename), CompressionType.ZIP, Arrays.asList("cat", "dog")); } @Test @@ -1340,5 +1346,21 @@ public void testInitialSplitGzipModeGz() throws Exception { SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); } -} + @Test + @Category(NeedsRunner.class) + public void testReadAll() throws IOException { + writeToFile(TINY, "readAllTiny1.zip", ZIP); + writeToFile(TINY, "readAllTiny2.zip", ZIP); + writeToFile(LARGE, "readAllLarge1.zip", ZIP); + writeToFile(LARGE, "readAllLarge2.zip", ZIP); + PCollection lines = + p.apply( + Create.of( + tempFolder.resolve("readAllTiny*").toString(), + tempFolder.resolve("readAllLarge*").toString())) + .apply(TextIO.readAll().withCompressionType(AUTO)); + PAssert.that(lines).containsInAnyOrder(Iterables.concat(TINY, TINY, LARGE, LARGE)); + p.run(); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index 0c2bd1c871d07..cb60f9a851a43 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.testing.PAssert; @@ -44,7 +45,6 @@ import org.apache.beam.sdk.testing.UsesTestStream; import 
org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; -import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java index 831894ca96929..8aed6b9c01caf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.apache.beam.sdk.io.range.OffsetRange; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; From 2b86a61e5bb07d3bd7f958e124bc8d79dc300c3f Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 11 Jul 2017 14:32:47 -0700 Subject: [PATCH 195/200] Cleanup and fix ptransform_fn decorator. Previously CallablePTransform was being used both as the factory and the transform itself, which could result in state getting carried between pipelines. --- .../apache_beam/transforms/combiners.py | 8 ++++ .../apache_beam/transforms/combiners_test.py | 7 +--- .../apache_beam/transforms/ptransform.py | 41 ++++++++----------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index fa0742db41797..875306f808296 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -149,6 +149,7 @@ class Top(object): """Combiners for obtaining extremal elements.""" # pylint: disable=no-self-argument + @staticmethod @ptransform.ptransform_fn def Of(pcoll, n, compare=None, *args, **kwargs): """Obtain a list of the compare-most N elements in a PCollection. @@ -177,6 +178,7 @@ def Of(pcoll, n, compare=None, *args, **kwargs): return pcoll | core.CombineGlobally( TopCombineFn(n, compare, key, reverse), *args, **kwargs) + @staticmethod @ptransform.ptransform_fn def PerKey(pcoll, n, compare=None, *args, **kwargs): """Identifies the compare-most N elements associated with each key. 
@@ -210,21 +212,25 @@ def PerKey(pcoll, n, compare=None, *args, **kwargs): return pcoll | core.CombinePerKey( TopCombineFn(n, compare, key, reverse), *args, **kwargs) + @staticmethod @ptransform.ptransform_fn def Largest(pcoll, n): """Obtain a list of the greatest N elements in a PCollection.""" return pcoll | Top.Of(n) + @staticmethod @ptransform.ptransform_fn def Smallest(pcoll, n): """Obtain a list of the least N elements in a PCollection.""" return pcoll | Top.Of(n, reverse=True) + @staticmethod @ptransform.ptransform_fn def LargestPerKey(pcoll, n): """Identifies the N greatest elements associated with each key.""" return pcoll | Top.PerKey(n) + @staticmethod @ptransform.ptransform_fn def SmallestPerKey(pcoll, n, reverse=True): """Identifies the N least elements associated with each key.""" @@ -369,10 +375,12 @@ class Sample(object): """Combiners for sampling n elements without replacement.""" # pylint: disable=no-self-argument + @staticmethod @ptransform.ptransform_fn def FixedSizeGlobally(pcoll, n): return pcoll | core.CombineGlobally(SampleCombineFn(n)) + @staticmethod @ptransform.ptransform_fn def FixedSizePerKey(pcoll, n): return pcoll | core.CombinePerKey(SampleCombineFn(n)) diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index c79fec864acb9..cd2b5956fef99 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -156,14 +156,11 @@ def individual_test_per_key_dd(combineFn): def test_combine_sample_display_data(self): def individual_test_per_key_dd(sampleFn, args, kwargs): - trs = [beam.CombinePerKey(sampleFn(*args, **kwargs)), - beam.CombineGlobally(sampleFn(*args, **kwargs))] + trs = [sampleFn(*args, **kwargs)] for transform in trs: dd = DisplayData.create_from(transform) expected_items = [ - DisplayDataItemMatcher('fn', sampleFn.fn.__name__), - DisplayDataItemMatcher('combine_fn', - transform.fn.__class__)] + DisplayDataItemMatcher('fn', transform._fn.__name__)] if args: expected_items.append( DisplayDataItemMatcher('args', str(args))) diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 60413535f65a9..cd84122d5e1ec 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -595,32 +595,23 @@ def default_label(self): return '%s(%s)' % (self.__class__.__name__, self.fn.default_label()) -class CallablePTransform(PTransform): +class _PTransformFnPTransform(PTransform): """A class wrapper for a function-based transform.""" - def __init__(self, fn): - # pylint: disable=super-init-not-called - # This is a helper class for a function decorator. Only when the class - # is called (and __call__ invoked) we will have all the information - # needed to initialize the super class. 
- self.fn = fn - self._args = () - self._kwargs = {} + def __init__(self, fn, *args, **kwargs): + super(_PTransformFnPTransform, self).__init__() + self._fn = fn + self._args = args + self._kwargs = kwargs def display_data(self): - res = {'fn': (self.fn.__name__ - if hasattr(self.fn, '__name__') - else self.fn.__class__), + res = {'fn': (self._fn.__name__ + if hasattr(self._fn, '__name__') + else self._fn.__class__), 'args': DisplayDataItem(str(self._args)).drop_if_default('()'), 'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')} return res - def __call__(self, *args, **kwargs): - super(CallablePTransform, self).__init__() - self._args = args - self._kwargs = kwargs - return self - def expand(self, pcoll): # Since the PTransform will be implemented entirely as a function # (once called), we need to pass through any type-hinting information that @@ -629,18 +620,18 @@ def expand(self, pcoll): kwargs = dict(self._kwargs) args = tuple(self._args) try: - if 'type_hints' in inspect.getargspec(self.fn).args: + if 'type_hints' in inspect.getargspec(self._fn).args: args = (self.get_type_hints(),) + args except TypeError: # Might not be a function. pass - return self.fn(pcoll, *args, **kwargs) + return self._fn(pcoll, *args, **kwargs) def default_label(self): if self._args: return '%s(%s)' % ( - label_from_callable(self.fn), label_from_callable(self._args[0])) - return label_from_callable(self.fn) + label_from_callable(self._fn), label_from_callable(self._args[0])) + return label_from_callable(self._fn) def ptransform_fn(fn): @@ -684,7 +675,11 @@ def expand(self, pcoll): operator (i.e., `|`) will inject the pcoll argument in its proper place (first argument if no label was specified and second argument otherwise). """ - return CallablePTransform(fn) + # TODO(robertwb): Consider removing staticmethod to allow for self parameter. 
+ + def callable_ptransform_factory(*args, **kwargs): + return _PTransformFnPTransform(fn, *args, **kwargs) + return callable_ptransform_factory def label_from_callable(fn): From 1bff4a786536ff1a4ffe9904079c7a89058e6b4e Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Tue, 13 Jun 2017 16:50:35 -0700 Subject: [PATCH 196/200] [BEAM-2447] Reintroduces DoFn.ProcessContinuation --- .../construction/SplittableParDoTest.java | 10 +- ...oundedSplittableProcessElementInvoker.java | 35 +++++- .../SplittableParDoViaKeyedWorkItems.java | 9 +- .../core/SplittableProcessElementInvoker.java | 25 ++++- ...edSplittableProcessElementInvokerTest.java | 45 ++++++-- .../core/SplittableParDoProcessFnTest.java | 99 +++++++++++++++-- .../org/apache/beam/sdk/transforms/DoFn.java | 51 ++++++++- .../reflect/ByteBuddyDoFnInvokerFactory.java | 19 +++- .../sdk/transforms/reflect/DoFnInvoker.java | 4 +- .../sdk/transforms/reflect/DoFnSignature.java | 10 +- .../transforms/reflect/DoFnSignatures.java | 22 +++- .../splittabledofn/OffsetRangeTracker.java | 10 ++ .../splittabledofn/RestrictionTracker.java | 11 +- .../sdk/transforms/SplittableDoFnTest.java | 100 +++++++----------- .../transforms/reflect/DoFnInvokersTest.java | 93 ++++++++++++---- .../DoFnSignaturesProcessElementTest.java | 2 +- .../DoFnSignaturesSplittableDoFnTest.java | 83 +++++++++++++-- 17 files changed, 487 insertions(+), 141 deletions(-) diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java index f4c596e019517..267232c028f6a 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.core.construction; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.junit.Assert.assertEquals; import java.io.Serializable; @@ -24,8 +25,6 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; -import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; @@ -70,7 +69,6 @@ public SomeRestriction checkpoint() { public void checkDone() {} } - @BoundedPerElement private static class BoundedFakeFn extends DoFn { @ProcessElement public void processElement(ProcessContext context, SomeRestrictionTracker tracker) {} @@ -81,10 +79,12 @@ public SomeRestriction getInitialRestriction(Integer element) { } } - @UnboundedPerElement private static class UnboundedFakeFn extends DoFn { @ProcessElement - public void processElement(ProcessContext context, SomeRestrictionTracker tracker) {} + public ProcessContinuation processElement( + ProcessContext context, SomeRestrictionTracker tracker) { + return stop(); + } @GetInitialRestriction public SomeRestriction getInitialRestriction(Integer element) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 475abf25eaa06..0c956d53af9da 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -96,7 +96,7 @@ public Result invokeProcessElement( final WindowedValue element, final TrackerT tracker) { final ProcessContext processContext = new ProcessContext(element, tracker); - invoker.invokeProcessElement( + DoFn.ProcessContinuation cont = invoker.invokeProcessElement( new DoFnInvoker.ArgumentProvider() { @Override public DoFn.ProcessContext processContext( @@ -155,10 +155,37 @@ public Timer timer(String timerId) { "Access to timers not supported in Splittable DoFn"); } }); - + // TODO: verify that if there was a failed tryClaim() call, then cont.shouldResume() is false. + // Currently we can't verify this because there are no hooks into tryClaim(). + // See https://issues.apache.org/jira/browse/BEAM-2607 + RestrictionT residual = processContext.extractCheckpoint(); + if (cont.shouldResume()) { + if (residual == null) { + // No checkpoint had been taken by the runner while the ProcessElement call ran, however + // the call says that not the whole restriction has been processed. So we need to take + // a checkpoint now: checkpoint() guarantees that the primary restriction describes exactly + // the work that was done in the current ProcessElement call, and returns a residual + // restriction that describes exactly the work that wasn't done in the current call. + residual = tracker.checkpoint(); + } else { + // A checkpoint was taken by the runner, and then the ProcessElement call returned resume() + // without making more tryClaim() calls (since no tryClaim() calls can succeed after + // checkpoint(), and since if it had made a failed tryClaim() call, it should have returned + // stop()). + // This means that the resulting primary restriction and the taken checkpoint already + // accurately describe respectively the work that was and wasn't done in the current + // ProcessElement call. + // In other words, if we took a checkpoint *after* ProcessElement completed (like in the + // branch above), it would have been equivalent to this one. + } + } else { + // The ProcessElement call returned stop() - that means the tracker's current restriction + // has been fully processed by the call. A checkpoint may or may not have been taken in + // "residual"; if it was, then we'll need to process it; if no, then we don't - nothing + // special needs to be done. 
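// Worked sketch of the primary/residual invariant described above, using the OffsetRangeTracker
// from this patch series for concreteness (the specific numbers are illustrative, not part of the
// patch): for a restriction [0, 100), if the ProcessElement call successfully claimed offsets
// 0..29 and then returned resume(),
//
//   OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0, 100));
//   for (long i = 0; i < 30; ++i) { tracker.tryClaim(i); }
//   OffsetRange residual = tracker.checkpoint();          // [30, 100) - exactly the work not done
//   OffsetRange primary = tracker.currentRestriction();   // [0, 30)   - exactly the work done
//
// so the primary and the residual together are equivalent to the original restriction.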
+ } tracker.checkDone(); - return new Result( - processContext.extractCheckpoint(), processContext.getLastReportedWatermark()); + return new Result(residual, cont, processContext.getLastReportedWatermark()); } private class ProcessContext extends DoFn.ProcessContext { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index 09f3b157f7be3..6e976455a9414 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -200,8 +200,8 @@ public static class ProcessFn< /** * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is * acquired during the first {@link DoFn.ProcessElement} call for each element and restriction, - * and is released when the {@link DoFn.ProcessElement} call returns and there is no residual - * restriction captured by the {@link SplittableProcessElementInvoker}. + * and is released when the {@link DoFn.ProcessElement} call returns {@link + * ProcessContinuation#stop()}. * *
    A hold is needed to avoid letting the output watermark immediately progress together with * the input watermark when the first {@link DoFn.ProcessElement} call for this element @@ -365,11 +365,12 @@ public void processElement(final ProcessContext c) { if (futureOutputWatermark == null) { futureOutputWatermark = elementAndRestriction.getKey().getTimestamp(); } + Instant wakeupTime = + timerInternals.currentProcessingTime().plus(result.getContinuation().resumeDelay()); holdState.add(futureOutputWatermark); // Set a timer to continue processing this element. timerInternals.setTimer( - TimerInternals.TimerData.of( - stateNamespace, timerInternals.currentProcessingTime(), TimeDomain.PROCESSING_TIME)); + TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME)); } private DoFn.StartBundleContext wrapContextAsStartBundle( diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java index ced6c015039f8..7732df371ac83 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.core; +import static com.google.common.base.Preconditions.checkNotNull; + import javax.annotation.Nullable; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; @@ -34,20 +36,35 @@ public abstract class SplittableProcessElementInvoker< public class Result { @Nullable private final RestrictionT residualRestriction; + private final DoFn.ProcessContinuation continuation; private final Instant futureOutputWatermark; public Result( - @Nullable RestrictionT residualRestriction, Instant futureOutputWatermark) { + @Nullable RestrictionT residualRestriction, + DoFn.ProcessContinuation continuation, + Instant futureOutputWatermark) { + this.continuation = checkNotNull(continuation); + if (continuation.shouldResume()) { + checkNotNull(residualRestriction); + } this.residualRestriction = residualRestriction; this.futureOutputWatermark = futureOutputWatermark; } - /** If {@code null}, means the call should not resume. */ + /** + * Can be {@code null} only if {@link #getContinuation} specifies the call should not resume. + * However, the converse is not true: this can be non-null even if {@link #getContinuation} + * is {@link DoFn.ProcessContinuation#stop()}. + */ @Nullable public RestrictionT getResidualRestriction() { return residualRestriction; } + public DoFn.ProcessContinuation getContinuation() { + return continuation; + } + public Instant getFutureOutputWatermark() { return futureOutputWatermark; } @@ -57,8 +74,8 @@ public Instant getFutureOutputWatermark() { * Invokes the {@link DoFn.ProcessElement} method using the given {@link DoFnInvoker} for the * original {@link DoFn}, on the given element and with the given {@link RestrictionTracker}. * - * @return Information on how to resume the call: residual restriction and a - * future output watermark. + * @return Information on how to resume the call: residual restriction, a {@link + * DoFn.ProcessContinuation}, and a future output watermark. 
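   * <p>A minimal sketch (not part of this patch) of how a runner consumes this result, mirroring
   * the ProcessFn change above: if the continuation asks to resume, re-schedule processing of the
   * residual restriction after the requested delay, e.g.
   *
   *   Result result = invoker.invokeProcessElement(doFnInvoker, element, tracker);
   *   if (result.getContinuation().shouldResume()) {
   *     Instant wakeupTime =
   *         timerInternals.currentProcessingTime().plus(result.getContinuation().resumeDelay());
   *     timerInternals.setTimer(
   *         TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME));
   *   }
   *
   * where {@code timerInternals} and {@code stateNamespace} are the runner-side handles used by
   * ProcessFn above.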
*/ public abstract Result invokeProcessElement( DoFnInvoker invoker, WindowedValue element, TrackerT tracker); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java index b80a6326d906d..959909e6690e4 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java @@ -17,11 +17,15 @@ */ package org.apache.beam.runners.core; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import java.util.Collection; import java.util.concurrent.Executors; @@ -42,19 +46,27 @@ /** Tests for {@link OutputAndTimeBoundedSplittableProcessElementInvoker}. */ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest { private static class SomeFn extends DoFn { + private final int numOutputsPerProcessCall; private final Duration sleepBeforeEachOutput; - private SomeFn(Duration sleepBeforeEachOutput) { + private SomeFn(int numOutputsPerProcessCall, Duration sleepBeforeEachOutput) { + this.numOutputsPerProcessCall = numOutputsPerProcessCall; this.sleepBeforeEachOutput = sleepBeforeEachOutput; } @ProcessElement - public void process(ProcessContext context, OffsetRangeTracker tracker) + public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tracker) throws Exception { - for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) { + for (long i = tracker.currentRestriction().getFrom(), numIterations = 1; + tracker.tryClaim(i); + ++i, ++numIterations) { Thread.sleep(sleepBeforeEachOutput.getMillis()); context.output("" + i); + if (numIterations == numOutputsPerProcessCall) { + return resume(); + } } + return stop(); } @GetInitialRestriction @@ -64,8 +76,8 @@ public OffsetRange getInitialRestriction(Integer element) { } private SplittableProcessElementInvoker.Result - runTest(int count, Duration sleepPerElement) { - SomeFn fn = new SomeFn(sleepPerElement); + runTest(int totalNumOutputs, int numOutputsPerProcessCall, Duration sleepPerElement) { + SomeFn fn = new SomeFn(numOutputsPerProcessCall, sleepPerElement); SplittableProcessElementInvoker invoker = new OutputAndTimeBoundedSplittableProcessElementInvoker<>( fn, @@ -93,14 +105,15 @@ public void outputWindowedValue( return invoker.invokeProcessElement( DoFnInvokers.invokerFor(fn), - WindowedValue.of(count, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING), - new OffsetRangeTracker(new OffsetRange(0, count))); + WindowedValue.of(totalNumOutputs, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING), + new OffsetRangeTracker(new OffsetRange(0, totalNumOutputs))); } @Test public void testInvokeProcessElementOutputBounded() throws Exception { SplittableProcessElementInvoker.Result res = - runTest(10000, Duration.ZERO); + runTest(10000, Integer.MAX_VALUE, Duration.ZERO); + 
assertFalse(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process the first 100 elements. assertEquals(1000, residualRange.getFrom()); @@ -110,7 +123,8 @@ public void testInvokeProcessElementOutputBounded() throws Exception { @Test public void testInvokeProcessElementTimeBounded() throws Exception { SplittableProcessElementInvoker.Result res = - runTest(10000, Duration.millis(100)); + runTest(10000, Integer.MAX_VALUE, Duration.millis(100)); + assertFalse(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process ideally around 30 elements - but due to timing flakiness, we can't enforce // that precisely. Just test that it's not egregiously off. @@ -120,9 +134,18 @@ public void testInvokeProcessElementTimeBounded() throws Exception { } @Test - public void testInvokeProcessElementVoluntaryReturn() throws Exception { + public void testInvokeProcessElementVoluntaryReturnStop() throws Exception { SplittableProcessElementInvoker.Result res = - runTest(5, Duration.millis(100)); + runTest(5, Integer.MAX_VALUE, Duration.millis(100)); + assertFalse(res.getContinuation().shouldResume()); assertNull(res.getResidualRestriction()); } + + @Test + public void testInvokeProcessElementVoluntaryReturnResume() throws Exception { + SplittableProcessElementInvoker.Result res = + runTest(10, 5, Duration.millis(100)); + assertTrue(res.getContinuation().shouldResume()); + assertEquals(new OffsetRange(5, 10), res.getResidualRestriction()); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java index 1cd127547cf12..7449af326e9bf 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java @@ -17,6 +17,9 @@ */ package org.apache.beam.runners.core; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasItems; @@ -365,16 +368,71 @@ public void testUpdatesWatermark() throws Exception { assertEquals(null, tester.getWatermarkHold()); } - /** - * A splittable {@link DoFn} that generates the sequence [init, init + total). - */ + /** A simple splittable {@link DoFn} that outputs the given element every 5 seconds forever. 
*/ + private static class SelfInitiatedResumeFn extends DoFn { + @ProcessElement + public ProcessContinuation process(ProcessContext c, SomeRestrictionTracker tracker) { + c.output(c.element().toString()); + return resume().withResumeDelay(Duration.standardSeconds(5)); + } + + @GetInitialRestriction + public SomeRestriction getInitialRestriction(Integer elem) { + return new SomeRestriction(); + } + } + + @Test + public void testResumeSetsTimer() throws Exception { + DoFn fn = new SelfInitiatedResumeFn(); + Instant base = Instant.now(); + ProcessFnTester tester = + new ProcessFnTester<>( + base, + fn, + BigEndianIntegerCoder.of(), + SerializableCoder.of(SomeRestriction.class), + MAX_OUTPUTS_PER_BUNDLE, + MAX_BUNDLE_DURATION); + + tester.startElement(42, new SomeRestriction()); + assertThat(tester.takeOutputElements(), contains("42")); + + // Should resume after 5 seconds: advancing by 3 seconds should have no effect. + assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); + assertTrue(tester.takeOutputElements().isEmpty()); + + // 6 seconds should be enough should invoke the fn again. + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); + assertThat(tester.takeOutputElements(), contains("42")); + + // Should again resume after 5 seconds: advancing by 3 seconds should again have no effect. + assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); + assertTrue(tester.takeOutputElements().isEmpty()); + + // 6 seconds should again be enough. + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); + assertThat(tester.takeOutputElements(), contains("42")); + } + + /** A splittable {@link DoFn} that generates the sequence [init, init + total). */ private static class CounterFn extends DoFn { + private final int numOutputsPerCall; + + public CounterFn(int numOutputsPerCall) { + this.numOutputsPerCall = numOutputsPerCall; + } + @ProcessElement - public void process(ProcessContext c, OffsetRangeTracker tracker) { - for (long i = tracker.currentRestriction().getFrom(); - tracker.tryClaim(i); ++i) { + public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) { + for (long i = tracker.currentRestriction().getFrom(), numIterations = 0; + tracker.tryClaim(i); ++i, ++numIterations) { c.output(String.valueOf(c.element() + i)); + if (numIterations == numOutputsPerCall) { + return resume(); + } } + return stop(); } @GetInitialRestriction @@ -383,10 +441,35 @@ public OffsetRange getInitialRestriction(Integer elem) { } } + public void testResumeCarriesOverState() throws Exception { + DoFn fn = new CounterFn(1); + Instant base = Instant.now(); + ProcessFnTester tester = + new ProcessFnTester<>( + base, + fn, + BigEndianIntegerCoder.of(), + SerializableCoder.of(OffsetRange.class), + MAX_OUTPUTS_PER_BUNDLE, + MAX_BUNDLE_DURATION); + + tester.startElement(42, new OffsetRange(0, 3)); + assertThat(tester.takeOutputElements(), contains("42")); + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + assertThat(tester.takeOutputElements(), contains("43")); + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + assertThat(tester.takeOutputElements(), contains("44")); + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + // After outputting all 3 items, should not output anything more. + assertEquals(0, tester.takeOutputElements().size()); + // Should also not ask to resume. 
+ assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + } + @Test public void testCheckpointsAfterNumOutputs() throws Exception { int max = 100; - DoFn fn = new CounterFn(); + DoFn fn = new CounterFn(Integer.MAX_VALUE); Instant base = Instant.now(); int baseIndex = 42; @@ -428,7 +511,7 @@ public void testCheckpointsAfterDuration() throws Exception { // But bound bundle duration - the bundle should terminate. Duration maxBundleDuration = Duration.standardSeconds(1); // Create an fn that attempts to 2x output more than checkpointing allows. - DoFn fn = new CounterFn(); + DoFn fn = new CounterFn(Integer.MAX_VALUE); Instant base = Instant.now(); int baseIndex = 42; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index a2e5c162c7cb5..1b809c2ff8a58 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.transforms; +import com.google.auto.value.AutoValue; import java.io.Serializable; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; @@ -545,11 +546,15 @@ public interface OutputReceiver { * returned by {@link GetInitialRestriction} implements {@link HasDefaultTracker}. *
  • It may define a {@link GetRestrictionCoder} method. *
  • The type of restrictions used by all of these methods must be the same. + *
  • Its {@link ProcessElement} method may return a {@link ProcessContinuation} to + * indicate whether there is more work to be done for the current element. *
  • Its {@link ProcessElement} method must not use any extra context parameters, such as * {@link BoundedWindow}. *
  • The {@link DoFn} itself may be annotated with {@link BoundedPerElement} or * {@link UnboundedPerElement}, but not both at the same time. If it's not annotated with - * either of these, it's assumed to be {@link BoundedPerElement}. + * either of these, it's assumed to be {@link BoundedPerElement} if its {@link + * ProcessElement} method returns {@code void} and {@link UnboundedPerElement} if it + * returns a {@link ProcessContinuation}. * * *
    A non-splittable {@link DoFn} must not define any of these methods. @@ -677,8 +682,48 @@ public interface OutputReceiver { @Experimental(Kind.SPLITTABLE_DO_FN) public @interface UnboundedPerElement {} - /** Temporary, do not use. See https://issues.apache.org/jira/browse/BEAM-1904 */ - public class ProcessContinuation {} + // This can't be put into ProcessContinuation itself due to the following problem: + // http://ternarysearch.blogspot.com/2013/07/static-initialization-deadlock.html + private static final ProcessContinuation PROCESS_CONTINUATION_STOP = + new AutoValue_DoFn_ProcessContinuation(false, Duration.ZERO); + + /** + * When used as a return value of {@link ProcessElement}, indicates whether there is more work to + * be done for the current element. + * + *
    If the {@link ProcessElement} call completes because of a failed {@code tryClaim()} call + * on the {@link RestrictionTracker}, then the call MUST return {@link #stop()}. + */ + @Experimental(Kind.SPLITTABLE_DO_FN) + @AutoValue + public abstract static class ProcessContinuation { + /** Indicates that there is no more work to be done for the current element. */ + public static ProcessContinuation stop() { + return PROCESS_CONTINUATION_STOP; + } + + /** Indicates that there is more work to be done for the current element. */ + public static ProcessContinuation resume() { + return new AutoValue_DoFn_ProcessContinuation(true, Duration.ZERO); + } + + /** + * If false, the {@link DoFn} promises that there is no more work remaining for the current + * element, so the runner should not resume the {@link ProcessElement} call. + */ + public abstract boolean shouldResume(); + + /** + * A minimum duration that should elapse between the end of this {@link ProcessElement} call and + * the {@link ProcessElement} call continuing processing of the same element. By default, zero. + */ + public abstract Duration resumeDelay(); + + /** Builder method to set the value of {@link #resumeDelay()}. */ + public ProcessContinuation withResumeDelay(Duration resumeDelay) { + return new AutoValue_DoFn_ProcessContinuation(shouldResume(), resumeDelay); + } + } /** * Finalize the {@link DoFn} construction to prepare for processing. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 837820411d9bd..cf96c9bea4f55 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -49,7 +49,6 @@ import net.bytebuddy.implementation.bytecode.assign.Assigner; import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; import net.bytebuddy.implementation.bytecode.assign.TypeCasting; -import net.bytebuddy.implementation.bytecode.constant.NullConstant; import net.bytebuddy.implementation.bytecode.constant.TextConstant; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; @@ -641,6 +640,17 @@ public StackManipulation dispatch(DoFnSignature.Parameter.PipelineOptionsParamet * {@link ProcessElement} method. */ private static final class ProcessElementDelegation extends DoFnMethodDelegation { + private static final MethodDescription PROCESS_CONTINUATION_STOP_METHOD; + + static { + try { + PROCESS_CONTINUATION_STOP_METHOD = + new MethodDescription.ForLoadedMethod(DoFn.ProcessContinuation.class.getMethod("stop")); + } catch (NoSuchMethodException e) { + throw new RuntimeException("Failed to locate ProcessContinuation.stop()"); + } + } + private final DoFnSignature.ProcessElementMethod signature; /** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. 
*/ @@ -677,7 +687,12 @@ protected StackManipulation beforeDelegation(MethodDescription instrumentedMetho @Override protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { - return new StackManipulation.Compound(NullConstant.INSTANCE, MethodReturn.REFERENCE); + if (TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure())) { + return new StackManipulation.Compound( + MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), MethodReturn.REFERENCE); + } else { + return MethodReturn.of(targetMethod.getReturnType().asErasure()); + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index 3b22fdaccb01c..8b41fee109eb5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -54,8 +54,8 @@ public interface DoFnInvoker { * Invoke the {@link DoFn.ProcessElement} method on the bound {@link DoFn}. * * @param extra Factory for producing extra parameter objects (such as window), if necessary. - * @return {@code null} - see JIRA - * tracking the complete removal of {@link DoFn.ProcessContinuation}. + * @return The {@link DoFn.ProcessContinuation} returned by the underlying method, or {@link + * DoFn.ProcessContinuation#stop()} if it returns {@code void}. */ DoFn.ProcessContinuation invokeProcessElement(ArgumentProvider extra); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 6eeed8e054950..bfad69ea77669 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerSpec; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; @@ -433,16 +434,21 @@ public abstract static class ProcessElementMethod implements MethodWithExtraPara @Nullable public abstract TypeDescriptor windowT(); + /** Whether this {@link DoFn} returns a {@link ProcessContinuation} or void. 
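   * <p>A sketch of how this flag is used downstream (illustrative only; the real adaptation is the
   * ByteBuddy delegation shown above): when the method returns {@code void}, the generated
   * {@code DoFnInvoker} substitutes {@code ProcessContinuation.stop()}, otherwise it forwards the
   * method's own return value, roughly
   *
   *   return hasReturnValue()
   *       ? continuationReturnedByProcessElement   // placeholder for the method's actual result
   *       : DoFn.ProcessContinuation.stop();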
*/ + public abstract boolean hasReturnValue(); + static ProcessElementMethod create( Method targetMethod, List extraParameters, TypeDescriptor trackerT, - @Nullable TypeDescriptor windowT) { + @Nullable TypeDescriptor windowT, + boolean hasReturnValue) { return new AutoValue_DoFnSignature_ProcessElementMethod( targetMethod, Collections.unmodifiableList(extraParameters), trackerT, - windowT); + windowT, + hasReturnValue); } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 1b27e66aa77fa..de57c3bed85a2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms.reflect; +import static com.google.common.base.Preconditions.checkState; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Predicates; @@ -440,6 +442,8 @@ private static DoFnSignature parseSignature(Class> fnClass) *
  • If the {@link DoFn} (or any of its supertypes) is annotated as {@link * DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, use that. Only one of * these must be specified. + *
  • If {@link DoFn.ProcessElement} returns {@link DoFn.ProcessContinuation}, assume it is + * unbounded. Otherwise (if it returns {@code void}), assume it is bounded. *
  • If {@link DoFn.ProcessElement} returns {@code void}, but the {@link DoFn} is annotated * {@link DoFn.UnboundedPerElement}, this is an error. * @@ -465,7 +469,10 @@ private static PCollection.IsBounded inferBoundedness( } if (processElement.isSplittable()) { if (isBounded == null) { - isBounded = PCollection.IsBounded.BOUNDED; + isBounded = + processElement.hasReturnValue() + ? PCollection.IsBounded.UNBOUNDED + : PCollection.IsBounded.BOUNDED; } } else { errors.checkArgument( @@ -474,6 +481,7 @@ private static PCollection.IsBounded inferBoundedness( + ((isBounded == PCollection.IsBounded.BOUNDED) ? DoFn.BoundedPerElement.class.getSimpleName() : DoFn.UnboundedPerElement.class.getSimpleName())); + checkState(!processElement.hasReturnValue(), "Should have been inferred splittable"); isBounded = PCollection.IsBounded.BOUNDED; } return isBounded; @@ -710,8 +718,10 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( TypeDescriptor outputT, FnAnalysisContext fnContext) { errors.checkArgument( - void.class.equals(m.getReturnType()), - "Must return void"); + void.class.equals(m.getReturnType()) + || DoFn.ProcessContinuation.class.equals(m.getReturnType()), + "Must return void or %s", + DoFn.ProcessContinuation.class.getSimpleName()); MethodAnalysisContext methodContext = MethodAnalysisContext.create(); @@ -751,7 +761,11 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( } return DoFnSignature.ProcessElementMethod.create( - m, methodContext.getExtraParameters(), trackerT, windowT); + m, + methodContext.getExtraParameters(), + trackerT, + windowT, + DoFn.ProcessContinuation.class.equals(m.getReturnType())); } private static void checkParameterOneOf( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java index 62c10a71ffa21..4987409e5cd82 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.base.MoreObjects; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.transforms.DoFn; @@ -100,4 +101,13 @@ public synchronized void checkDone() throws IllegalStateException { lastAttemptedOffset + 1, range.getTo()); } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("range", range) + .add("lastClaimedOffset", lastClaimedOffset) + .add("lastAttemptedOffset", lastAttemptedOffset) + .toString(); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java index 27ef68f4a980c..8cb0a6bd4baac 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java @@ -31,10 +31,13 @@ public interface RestrictionTracker { RestrictionT currentRestriction(); /** - * Signals that the current {@link DoFn.ProcessElement} call should terminate as soon as possible. - * Modifies {@link #currentRestriction}. 
Returns a restriction representing the rest of the work: - * the old value of {@link #currentRestriction} is equivalent to the new value and the return - * value of this method combined. Must be called at most once on a given object. + * Signals that the current {@link DoFn.ProcessElement} call should terminate as soon as possible: + * after this method returns, the tracker MUST refuse all future claim calls, and {@link + * #checkDone} MUST succeed. + * + *
    Modifies {@link #currentRestriction}. Returns a restriction representing the rest of the + * work: the old value of {@link #currentRestriction} is equivalent to the new value and the + * return value of this method combined. Must be called at most once on a given object. */ RestrictionT checkpoint(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index cb60f9a851a43..d2d2529871f5b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -19,10 +19,10 @@ import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.sdk.testing.TestPipeline.testingPipelineOptions; -import static org.hamcrest.Matchers.greaterThan; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import com.google.common.collect.Ordering; @@ -33,7 +33,6 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.StreamingOptions; @@ -74,10 +73,16 @@ public class SplittableDoFnTest implements Serializable { static class PairStringWithIndexToLength extends DoFn> { @ProcessElement - public void process(ProcessContext c, OffsetRangeTracker tracker) { - for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) { + public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) { + for (long i = tracker.currentRestriction().getFrom(), numIterations = 0; + tracker.tryClaim(i); + ++i, ++numIterations) { c.output(KV.of(c.element(), (int) i)); + if (numIterations % 3 == 0) { + return resume(); + } } + return stop(); } @GetInitialRestriction @@ -206,10 +211,10 @@ public void testPairWithIndexWindowedTimestamped() { private static class SDFWithMultipleOutputsPerBlock extends DoFn { private static final int MAX_INDEX = 98765; - private final TupleTag numProcessCalls; + private final int numClaimsPerCall; - private SDFWithMultipleOutputsPerBlock(TupleTag numProcessCalls) { - this.numProcessCalls = numProcessCalls; + private SDFWithMultipleOutputsPerBlock(int numClaimsPerCall) { + this.numClaimsPerCall = numClaimsPerCall; } private static int snapToNextBlock(int index, int[] blockStarts) { @@ -222,15 +227,20 @@ private static int snapToNextBlock(int index, int[] blockStarts) { } @ProcessElement - public void processElement(ProcessContext c, OffsetRangeTracker tracker) { + public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker tracker) { int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); - c.output(numProcessCalls, 1); - for (int i = trueStart; tracker.tryClaim(blockStarts[i]); ++i) { + for (int i = trueStart, numIterations = 1; + tracker.tryClaim(blockStarts[i]); + ++i, ++numIterations) { for (int index = blockStarts[i]; 
index < blockStarts[i + 1]; ++index) { c.output(index); } + if (numIterations == numClaimsPerCall) { + return resume(); + } } + return stop(); } @GetInitialRestriction @@ -242,26 +252,10 @@ public OffsetRange getInitialRange(String element) { @Test @Category({ValidatesRunner.class, UsesSplittableParDo.class}) public void testOutputAfterCheckpoint() throws Exception { - TupleTag main = new TupleTag<>(); - TupleTag numProcessCalls = new TupleTag<>(); - PCollectionTuple outputs = - p.apply(Create.of("foo")) - .apply( - ParDo.of(new SDFWithMultipleOutputsPerBlock(numProcessCalls)) - .withOutputTags(main, TupleTagList.of(numProcessCalls))); - PAssert.thatSingleton(outputs.get(main).apply(Count.globally())) + PCollection outputs = p.apply(Create.of("foo")) + .apply(ParDo.of(new SDFWithMultipleOutputsPerBlock(3))); + PAssert.thatSingleton(outputs.apply(Count.globally())) .isEqualTo((long) SDFWithMultipleOutputsPerBlock.MAX_INDEX); - // Verify that more than 1 process() call was involved, i.e. that there was checkpointing. - PAssert.thatSingleton( - outputs.get(numProcessCalls).setCoder(VarIntCoder.of()).apply(Sum.integersGlobally())) - .satisfies( - new SerializableFunction() { - @Override - public Void apply(Integer input) { - assertThat(input, greaterThan(1)); - return null; - } - }); p.run(); } @@ -341,12 +335,12 @@ private static class SDFWithMultipleOutputsPerBlockAndSideInput extends DoFn> { private static final int MAX_INDEX = 98765; private final PCollectionView sideInput; - private final TupleTag numProcessCalls; + private final int numClaimsPerCall; public SDFWithMultipleOutputsPerBlockAndSideInput( - PCollectionView sideInput, TupleTag numProcessCalls) { + PCollectionView sideInput, int numClaimsPerCall) { this.sideInput = sideInput; - this.numProcessCalls = numProcessCalls; + this.numClaimsPerCall = numClaimsPerCall; } private static int snapToNextBlock(int index, int[] blockStarts) { @@ -359,15 +353,20 @@ private static int snapToNextBlock(int index, int[] blockStarts) { } @ProcessElement - public void processElement(ProcessContext c, OffsetRangeTracker tracker) { + public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker tracker) { int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); - c.output(numProcessCalls, 1); - for (int i = trueStart; tracker.tryClaim(blockStarts[i]); ++i) { + for (int i = trueStart, numIterations = 1; + tracker.tryClaim(blockStarts[i]); + ++i, ++numIterations) { for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { c.output(KV.of(c.sideInput(sideInput) + ":" + c.element(), index)); } + if (numIterations == numClaimsPerCall) { + return resume(); + } } + return stop(); } @GetInitialRestriction @@ -400,15 +399,14 @@ public void testWindowedSideInputWithCheckpoints() throws Exception { .apply("window 2", Window.into(FixedWindows.of(Duration.millis(2)))) .apply("singleton", View.asSingleton()); - TupleTag> main = new TupleTag<>(); - TupleTag numProcessCalls = new TupleTag<>(); - PCollectionTuple res = + PCollection> res = mainInput.apply( - ParDo.of(new SDFWithMultipleOutputsPerBlockAndSideInput(sideInput, numProcessCalls)) - .withSideInputs(sideInput) - .withOutputTags(main, TupleTagList.of(numProcessCalls))); + ParDo.of( + new SDFWithMultipleOutputsPerBlockAndSideInput( + sideInput, 3 /* numClaimsPerCall */)) + .withSideInputs(sideInput)); PCollection>> grouped = - res.get(main).apply(GroupByKey.create()); + 
res.apply(GroupByKey.create()); PAssert.that(grouped.apply(Keys.create())) .containsInAnyOrder("a:0", "a:1", "b:2", "b:3"); @@ -427,22 +425,6 @@ public Void apply(Iterable>> input) { return null; } }); - - // Verify that more than 1 process() call was involved, i.e. that there was checkpointing. - PAssert.thatSingleton( - res.get(numProcessCalls) - .setCoder(VarIntCoder.of()) - .apply(Sum.integersGlobally().withoutDefaults())) - // This should hold in all windows, but verifying a particular window is sufficient. - .inOnlyPane(new IntervalWindow(new Instant(0), new Instant(1))) - .satisfies( - new SerializableFunction() { - @Override - public Void apply(Integer input) { - assertThat(input, greaterThan(1)); - return null; - } - }); p.run(); // TODO: also test coverage when some of the windows of the side input are not ready. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 3edb19478b3f2..2098c664bb664 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms.reflect; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; +import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; @@ -89,8 +91,8 @@ public void setUp() { when(mockArgumentProvider.processContext(Matchers.any())).thenReturn(mockProcessContext); } - private void invokeProcessElement(DoFn fn) { - DoFnInvokers.invokerFor(fn).invokeProcessElement(mockArgumentProvider); + private DoFn.ProcessContinuation invokeProcessElement(DoFn fn) { + return DoFnInvokers.invokerFor(fn).invokeProcessElement(mockArgumentProvider); } private void invokeOnTimer(String timerId, DoFn fn) { @@ -119,7 +121,7 @@ class MockFn extends DoFn { public void processElement(ProcessContext c) throws Exception {} } MockFn mockFn = mock(MockFn.class); - invokeProcessElement(mockFn); + assertEquals(stop(), invokeProcessElement(mockFn)); verify(mockFn).processElement(mockProcessContext); } @@ -140,7 +142,7 @@ public void processElement(DoFn.ProcessContext c) {} public void testDoFnWithProcessElementInterface() throws Exception { IdentityUsingInterfaceWithProcessElement fn = mock(IdentityUsingInterfaceWithProcessElement.class); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); verify(fn).processElement(mockProcessContext); } @@ -161,14 +163,14 @@ public void process(DoFn.ProcessContext c) { @Test public void testDoFnWithMethodInSuperclass() throws Exception { IdentityChildWithoutOverride fn = mock(IdentityChildWithoutOverride.class); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); verify(fn).process(mockProcessContext); } @Test public void testDoFnWithMethodInSubclass() throws Exception { IdentityChildWithOverride fn = mock(IdentityChildWithOverride.class); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); verify(fn).process(mockProcessContext); } @@ -179,7 +181,7 @@ class MockFn extends DoFn { public void processElement(ProcessContext c, IntervalWindow w) throws Exception {} } MockFn fn = mock(MockFn.class); - invokeProcessElement(fn); + assertEquals(stop(), 
invokeProcessElement(fn)); verify(fn).processElement(mockProcessContext, mockWindow); } @@ -203,7 +205,7 @@ public void processElement(ProcessContext c, @StateId(stateId) ValueState { + @DoFn.ProcessElement + public ProcessContinuation processElement(ProcessContext c, SomeRestrictionTracker tracker) + throws Exception { + return null; + } + + @GetInitialRestriction + public SomeRestriction getInitialRestriction(String element) { + return null; + } + + @NewTracker + public SomeRestrictionTracker newTracker(SomeRestriction restriction) { + return null; + } + } + MockFn fn = mock(MockFn.class); + when(fn.processElement(mockProcessContext, null)).thenReturn(resume()); + assertEquals(resume(), invokeProcessElement(fn)); + } + @Test public void testDoFnWithStartBundleSetupTeardown() throws Exception { class MockFn extends DoFn { @@ -288,7 +314,9 @@ public SomeRestriction decode(InputStream inStream) { /** Public so Mockito can do "delegatesTo()" in the test below. */ public static class MockFn extends DoFn { @ProcessElement - public void processElement(ProcessContext c, SomeRestrictionTracker tracker) {} + public ProcessContinuation processElement(ProcessContext c, SomeRestrictionTracker tracker) { + return null; + } @GetInitialRestriction public SomeRestriction getInitialRestriction(String element) { @@ -340,7 +368,7 @@ public void splitRestriction( .splitRestriction( eq("blah"), same(restriction), Mockito.>any()); when(fn.newTracker(restriction)).thenReturn(tracker); - fn.processElement(mockProcessContext, tracker); + when(fn.processElement(mockProcessContext, tracker)).thenReturn(resume()); assertEquals(coder, invoker.invokeGetRestrictionCoder(CoderRegistry.createDefault())); assertEquals(restriction, invoker.invokeGetInitialRestriction("blah")); @@ -356,6 +384,8 @@ public void output(SomeRestriction output) { }); assertEquals(Arrays.asList(part1, part2, part3), outputs); assertEquals(tracker, invoker.invokeNewTracker(restriction)); + assertEquals( + resume(), invoker.invokeProcessElement( new FakeArgumentProvider() { @Override @@ -367,7 +397,7 @@ public DoFn.ProcessContext processContext(DoFn f public RestrictionTracker restrictionTracker() { return tracker; } - }); + })); } private static class RestrictionWithDefaultTracker @@ -441,7 +471,7 @@ public void output(String output) { assertEquals("foo", output); } }); - invoker.invokeProcessElement(mockArgumentProvider); + assertEquals(stop(), invoker.invokeProcessElement(mockArgumentProvider)); assertThat( invoker.invokeNewTracker(new RestrictionWithDefaultTracker()), instanceOf(DefaultTracker.class)); @@ -531,14 +561,14 @@ public void processThis(ProcessContext c) {} @Test public void testLocalPrivateDoFnClass() throws Exception { PrivateDoFnClass fn = mock(PrivateDoFnClass.class); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); verify(fn).processThis(mockProcessContext); } @Test public void testStaticPackagePrivateDoFnClass() throws Exception { DoFn fn = mock(DoFnInvokersTestHelper.newStaticPackagePrivateDoFn().getClass()); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); DoFnInvokersTestHelper.verifyStaticPackagePrivateDoFn(fn, mockProcessContext); } @@ -546,28 +576,28 @@ public void testStaticPackagePrivateDoFnClass() throws Exception { public void testInnerPackagePrivateDoFnClass() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn().getClass()); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); 
DoFnInvokersTestHelper.verifyInnerPackagePrivateDoFn(fn, mockProcessContext); } @Test public void testStaticPrivateDoFnClass() throws Exception { DoFn fn = mock(DoFnInvokersTestHelper.newStaticPrivateDoFn().getClass()); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); DoFnInvokersTestHelper.verifyStaticPrivateDoFn(fn, mockProcessContext); } @Test public void testInnerPrivateDoFnClass() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerPrivateDoFn().getClass()); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); DoFnInvokersTestHelper.verifyInnerPrivateDoFn(fn, mockProcessContext); } @Test public void testAnonymousInnerDoFn() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerAnonymousDoFn().getClass()); - invokeProcessElement(fn); + assertEquals(stop(), invokeProcessElement(fn)); DoFnInvokersTestHelper.verifyInnerAnonymousDoFn(fn, mockProcessContext); } @@ -603,6 +633,31 @@ public DoFn.ProcessContext processContext(DoFn() { + @ProcessElement + public ProcessContinuation processElement( + @SuppressWarnings("unused") ProcessContext c, SomeRestrictionTracker tracker) { + throw new IllegalArgumentException("bogus"); + } + + @GetInitialRestriction + public SomeRestriction getInitialRestriction(Integer element) { + return null; + } + + @NewTracker + public SomeRestrictionTracker newTracker(SomeRestriction restriction) { + return null; + } + }) + .invokeProcessElement(new FakeArgumentProvider()); + } + @Test public void testStartBundleException() throws Exception { DoFnInvoker invoker = diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java index d321f54d68bd7..44ae5c4f2425a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java @@ -50,7 +50,7 @@ private void method(DoFn.ProcessContext c, Integer n) {} @Test public void testBadReturnType() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Must return void"); + thrown.expectMessage("Must return void or ProcessContinuation"); analyzeProcessElementMethod( new AnonymousMethod() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 07b3348fe1010..08af65e93e7c3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -52,13 +52,28 @@ public class DoFnSignaturesSplittableDoFnTest { @Rule public ExpectedException thrown = ExpectedException.none(); - private static class SomeRestriction {} + private abstract static class SomeRestriction + implements HasDefaultTracker {} private abstract static class SomeRestrictionTracker implements RestrictionTracker {} private abstract static class SomeRestrictionCoder extends StructuredCoder {} + @Test + public void testReturnsProcessContinuation() throws Exception { + DoFnSignature.ProcessElementMethod signature = + analyzeProcessElementMethod( + new AnonymousMethod() { + private 
DoFn.ProcessContinuation method( + DoFn.ProcessContext context) { + return null; + } + }); + + assertTrue(signature.hasReturnValue()); + } + @Test public void testHasRestrictionTracker() throws Exception { DoFnSignature.ProcessElementMethod signature = @@ -100,11 +115,6 @@ public void processElement(ProcessContext context, SomeRestrictionTracker tracke public SomeRestriction getInitialRestriction(Integer element) { return null; } - - @NewTracker - public SomeRestrictionTracker newTracker(SomeRestriction restriction) { - return null; - } } @BoundedPerElement @@ -130,6 +140,55 @@ class UnboundedSplittableFn extends BaseSplittableFn {} .isBoundedPerElement()); } + private static class BaseFnWithoutContinuation extends DoFn { + @ProcessElement + public void processElement(ProcessContext context, SomeRestrictionTracker tracker) {} + + @GetInitialRestriction + public SomeRestriction getInitialRestriction(Integer element) { + return null; + } + } + + private static class BaseFnWithContinuation extends DoFn { + @ProcessElement + public ProcessContinuation processElement( + ProcessContext context, SomeRestrictionTracker tracker) { + return null; + } + + @GetInitialRestriction + public SomeRestriction getInitialRestriction(Integer element) { + return null; + } + } + + @Test + public void testSplittableBoundednessInferredFromReturnValue() throws Exception { + assertEquals( + PCollection.IsBounded.BOUNDED, + DoFnSignatures.getSignature(BaseFnWithoutContinuation.class).isBoundedPerElement()); + assertEquals( + PCollection.IsBounded.UNBOUNDED, + DoFnSignatures.getSignature(BaseFnWithContinuation.class).isBoundedPerElement()); + } + + @Test + public void testSplittableRespectsBoundednessAnnotation() throws Exception { + @BoundedPerElement + class BoundedFnWithContinuation extends BaseFnWithContinuation {} + + assertEquals( + PCollection.IsBounded.BOUNDED, + DoFnSignatures.getSignature(BoundedFnWithContinuation.class).isBoundedPerElement()); + + @UnboundedPerElement + class UnboundedFnWithContinuation extends BaseFnWithContinuation {} + + assertEquals( + PCollection.IsBounded.UNBOUNDED, + DoFnSignatures.getSignature(UnboundedFnWithContinuation.class).isBoundedPerElement()); + } @Test public void testUnsplittableIsBounded() throws Exception { class UnsplittableFn extends DoFn { @@ -172,8 +231,10 @@ public void process(ProcessContext context) {} public void testSplittableWithAllFunctions() throws Exception { class GoodSplittableDoFn extends DoFn { @ProcessElement - public void processElement( - ProcessContext context, SomeRestrictionTracker tracker) {} + public ProcessContinuation processElement( + ProcessContext context, SomeRestrictionTracker tracker) { + return null; + } @GetInitialRestriction public SomeRestriction getInitialRestriction(Integer element) { @@ -198,6 +259,7 @@ public SomeRestrictionCoder getRestrictionCoder() { DoFnSignature signature = DoFnSignatures.getSignature(GoodSplittableDoFn.class); assertEquals(SomeRestrictionTracker.class, signature.processElement().trackerT().getRawType()); assertTrue(signature.processElement().isSplittable()); + assertTrue(signature.processElement().hasReturnValue()); assertEquals( SomeRestriction.class, signature.getInitialRestriction().restrictionT().getRawType()); assertEquals(SomeRestriction.class, signature.splitRestriction().restrictionT().getRawType()); @@ -214,7 +276,9 @@ public SomeRestrictionCoder getRestrictionCoder() { public void testSplittableWithAllFunctionsGeneric() throws Exception { class GoodGenericSplittableDoFn extends DoFn { 
@ProcessElement - public void processElement(ProcessContext context, TrackerT tracker) {} + public ProcessContinuation processElement(ProcessContext context, TrackerT tracker) { + return null; + } @GetInitialRestriction public RestrictionT getInitialRestriction(Integer element) { @@ -242,6 +306,7 @@ public CoderT getRestrictionCoder() { SomeRestriction, SomeRestrictionTracker, SomeRestrictionCoder>() {}.getClass()); assertEquals(SomeRestrictionTracker.class, signature.processElement().trackerT().getRawType()); assertTrue(signature.processElement().isSplittable()); + assertTrue(signature.processElement().hasReturnValue()); assertEquals( SomeRestriction.class, signature.getInitialRestriction().restrictionT().getRawType()); assertEquals(SomeRestriction.class, signature.splitRestriction().restrictionT().getRawType()); From 016baf3465bbccbc9d3df6999b38b1b2533aee8c Mon Sep 17 00:00:00 2001 From: Colin Phipps Date: Mon, 10 Jul 2017 16:09:23 +0000 Subject: [PATCH 197/200] Implement retries in the read connector. Respect non-retryable error codes from Datastore. Add unit tests to check that retryable errors are retried. --- .../sdk/io/gcp/datastore/DatastoreV1.java | 45 +++++++++++++++- .../sdk/io/gcp/datastore/DatastoreV1Test.java | 51 ++++++++++++++++++- 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java index 5f65428141af3..1ed643014a739 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java @@ -40,6 +40,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.datastore.v1.CommitRequest; import com.google.datastore.v1.Entity; import com.google.datastore.v1.EntityResult; @@ -65,6 +66,7 @@ import java.util.ArrayList; import java.util.List; import java.util.NoSuchElementException; +import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.annotations.Experimental; @@ -237,6 +239,14 @@ public class DatastoreV1 { @VisibleForTesting static final int DATASTORE_BATCH_UPDATE_BYTES_LIMIT = 9_000_000; + /** + * Non-retryable errors. + * See https://cloud.google.com/datastore/docs/concepts/errors#Error_Codes . + */ + private static final Set NON_RETRYABLE_ERRORS = + ImmutableSet.of(Code.FAILED_PRECONDITION, Code.INVALID_ARGUMENT, Code.PERMISSION_DENIED, + Code.UNAUTHENTICATED); + /** * Returns an empty {@link DatastoreV1.Read} builder. 
Configure the source {@code projectId}, * {@code query}, and optionally {@code namespace} and {@code numQuerySplits} using @@ -840,6 +850,14 @@ static class ReadFn extends DoFn { private final V1DatastoreFactory datastoreFactory; // Datastore client private transient Datastore datastore; + private final Counter rpcErrors = + Metrics.counter(DatastoreWriterFn.class, "datastoreRpcErrors"); + private final Counter rpcSuccesses = + Metrics.counter(DatastoreWriterFn.class, "datastoreRpcSuccesses"); + private static final int MAX_RETRIES = 5; + private static final FluentBackoff RUNQUERY_BACKOFF = + FluentBackoff.DEFAULT + .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5)); public ReadFn(V1Options options) { this(options, new V1DatastoreFactory()); @@ -857,6 +875,28 @@ public void startBundle(StartBundleContext c) throws Exception { options.getLocalhost()); } + private RunQueryResponse runQueryWithRetries(RunQueryRequest request) throws Exception { + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backoff = RUNQUERY_BACKOFF.backoff(); + while (true) { + try { + RunQueryResponse response = datastore.runQuery(request); + rpcSuccesses.inc(); + return response; + } catch (DatastoreException exception) { + rpcErrors.inc(); + + if (NON_RETRYABLE_ERRORS.contains(exception.getCode())) { + throw exception; + } + if (!BackOffUtils.next(sleeper, backoff)) { + LOG.error("Aborting after {} retries.", MAX_RETRIES); + throw exception; + } + } + } + } + /** Read and output entities for the given query. */ @ProcessElement public void processElement(ProcessContext context) throws Exception { @@ -878,7 +918,7 @@ public void processElement(ProcessContext context) throws Exception { } RunQueryRequest request = makeRequest(queryBuilder.build(), namespace); - RunQueryResponse response = datastore.runQuery(request); + RunQueryResponse response = runQueryWithRetries(request); currentBatch = response.getBatch(); @@ -1328,6 +1368,9 @@ private void flushBatch() throws DatastoreException, IOException, InterruptedExc exception.getCode(), exception.getMessage()); rpcErrors.inc(); + if (NON_RETRYABLE_ERRORS.contains(exception.getCode())) { + throw exception; + } if (!BackOffUtils.next(sleeper, backoff)) { LOG.error("Aborting after {} retries.", MAX_RETRIES); throw exception; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java index a3f5d38ae886b..cfba1d6f95992 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1Test.java @@ -51,6 +51,7 @@ import static org.mockito.Mockito.when; import com.google.datastore.v1.CommitRequest; +import com.google.datastore.v1.CommitResponse; import com.google.datastore.v1.Entity; import com.google.datastore.v1.EntityResult; import com.google.datastore.v1.GqlQuery; @@ -682,6 +683,29 @@ public void testDatatoreWriterFnWithLargeEntities() throws Exception { } } + /** Tests {@link DatastoreWriterFn} with a failed request which is retried. 
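   * <p>The retry behaviour being exercised here, reduced to a sketch (the backoff parameters and
   * the non-retryable code set are the ones added to DatastoreV1 above; everything else is
   * illustrative):
   *
   *   BackOff backoff = FluentBackoff.DEFAULT
   *       .withMaxRetries(5).withInitialBackoff(Duration.standardSeconds(5)).backoff();
   *   while (true) {
   *     try {
   *       return datastore.runQuery(request);          // or datastore.commit(...) on the write path
   *     } catch (DatastoreException exception) {
   *       if (NON_RETRYABLE_ERRORS.contains(exception.getCode())
   *           || !BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
   *         throw exception;                           // non-retryable code, or retries exhausted
   *       }
   *     }
   *   }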
*/ + @Test + public void testDatatoreWriterFnRetriesErrors() throws Exception { + List mutations = new ArrayList<>(); + int numRpcs = 2; + for (int i = 0; i < DATASTORE_BATCH_UPDATE_ENTITIES_START * numRpcs; ++i) { + mutations.add( + makeUpsert(Entity.newBuilder().setKey(makeKey("key" + i, i + 1)).build()).build()); + } + + CommitResponse successfulCommit = CommitResponse.getDefaultInstance(); + when(mockDatastore.commit(any(CommitRequest.class))).thenReturn(successfulCommit) + .thenThrow( + new DatastoreException("commit", Code.DEADLINE_EXCEEDED, "", null)) + .thenReturn(successfulCommit); + + DatastoreWriterFn datastoreWriter = new DatastoreWriterFn(StaticValueProvider.of(PROJECT_ID), + null, mockDatastoreFactory, new FakeWriteBatcher()); + DoFnTester doFnTester = DoFnTester.of(datastoreWriter); + doFnTester.setCloningBehavior(CloningBehavior.DO_NOT_CLONE); + doFnTester.processBundle(mutations); + } + /** * Tests {@link DatastoreV1.Read#getEstimatedSizeBytes} to fetch and return estimated size for a * query. @@ -816,6 +840,31 @@ public void testReadFnWithBatchesExactMultiple() throws Exception { readFnTest(5 * QUERY_BATCH_LIMIT); } + /** Tests that {@link ReadFn} retries after an error. */ + @Test + public void testReadFnRetriesErrors() throws Exception { + // An empty query to read entities. + Query query = Query.newBuilder().setLimit( + Int32Value.newBuilder().setValue(1)).build(); + + // Use mockResponseForQuery to generate results. + when(mockDatastore.runQuery(any(RunQueryRequest.class))) + .thenThrow( + new DatastoreException("RunQuery", Code.DEADLINE_EXCEEDED, "", null)) + .thenAnswer(new Answer() { + @Override + public RunQueryResponse answer(InvocationOnMock invocationOnMock) throws Throwable { + Query q = ((RunQueryRequest) invocationOnMock.getArguments()[0]).getQuery(); + return mockResponseForQuery(q); + } + }); + + ReadFn readFn = new ReadFn(V_1_OPTIONS, mockDatastoreFactory); + DoFnTester doFnTester = DoFnTester.of(readFn); + doFnTester.setCloningBehavior(CloningBehavior.DO_NOT_CLONE); + List entities = doFnTester.processBundle(query); + } + @Test public void testTranslateGqlQueryWithLimit() throws Exception { String gql = "SELECT * from DummyKind LIMIT 10"; @@ -1096,7 +1145,7 @@ public void addRequestLatency(long timeSinceEpochMillis, long latencyMillis, int } @Override public int nextBatchSize(long timeSinceEpochMillis) { - return DatastoreV1.DATASTORE_BATCH_UPDATE_ENTITIES_START; + return DATASTORE_BATCH_UPDATE_ENTITIES_START; } } } From 64997efa597a6fd74f4a6b6a7ab48d663c56845f Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Mon, 10 Jul 2017 21:30:50 -0700 Subject: [PATCH 198/200] Unbundle Context and WindowedContext. 
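
An illustrative sketch, not part of the change itself: with Context and WindowedContext removed, a
FilenamePolicy is written directly against the parameter-based signatures introduced in this patch.
The ExamplePolicy class below and its naming scheme are hypothetical, and the import paths are
assumed from the surrounding code in this series:

    import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
    import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints;
    import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions;
    import org.apache.beam.sdk.io.fs.ResourceId;
    import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
    import org.apache.beam.sdk.transforms.windowing.PaneInfo;

    /** Hypothetical policy used only to illustrate the new parameter-based API. */
    class ExamplePolicy extends FilenamePolicy {
      private final ResourceId baseFilename;

      ExamplePolicy(ResourceId baseFilename) {
        this.baseFilename = baseFilename;
      }

      @Override
      public ResourceId windowedFilename(
          int shardNumber,
          int numShards,
          BoundedWindow window,
          PaneInfo paneInfo,
          OutputFileHints outputFileHints) {
        // Window, pane and shard information arrive as plain parameters instead of a WindowedContext.
        String filename =
            String.format(
                "example-%s-pane-%s-%s-of-%s%s",
                window, paneInfo.getIndex(), shardNumber, numShards,
                outputFileHints.getSuggestedFilenameSuffix());
        return baseFilename
            .getCurrentDirectory()
            .resolve(filename, StandardResolveOptions.RESOLVE_FILE);
      }

      @Override
      public ResourceId unwindowedFilename(
          int shardNumber, int numShards, OutputFileHints outputFileHints) {
        // Shard information arrives as plain parameters instead of a Context.
        String filename =
            String.format(
                "example-%s-of-%s%s",
                shardNumber, numShards, outputFileHints.getSuggestedFilenameSuffix());
        return baseFilename
            .getCurrentDirectory()
            .resolve(filename, StandardResolveOptions.RESOLVE_FILE);
      }
    }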
--- .../common/WriteOneFilePerWindow.java | 19 +- .../complete/game/utils/WriteToText.java | 18 +- .../WriteFilesTranslationTest.java | 12 +- .../beam/sdk/io/DefaultFilenamePolicy.java | 47 ++-- .../org/apache/beam/sdk/io/FileBasedSink.java | 198 ++++--------- .../org/apache/beam/sdk/io/AvroIOTest.java | 263 ++++++++++-------- .../apache/beam/sdk/io/FileBasedSinkTest.java | 88 +++--- .../apache/beam/sdk/io/WriteFilesTest.java | 122 ++++---- 8 files changed, 358 insertions(+), 409 deletions(-) diff --git a/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java b/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java index 49865ba60b7c9..abd14b70118fd 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java +++ b/examples/java/src/main/java/org/apache/beam/examples/common/WriteOneFilePerWindow.java @@ -28,7 +28,9 @@ import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.joda.time.format.DateTimeFormatter; @@ -88,14 +90,18 @@ public String filenamePrefixForWindow(IntervalWindow window) { } @Override - public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { - IntervalWindow window = (IntervalWindow) context.getWindow(); + public ResourceId windowedFilename(int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { + IntervalWindow intervalWindow = (IntervalWindow) window; String filename = String.format( "%s-%s-of-%s%s", - filenamePrefixForWindow(window), - context.getShardNumber(), - context.getNumShards(), + filenamePrefixForWindow(intervalWindow), + shardNumber, + numShards, outputFileHints.getSuggestedFilenameSuffix()); return baseFilename .getCurrentDirectory() @@ -103,7 +109,8 @@ public ResourceId windowedFilename(WindowedContext context, OutputFileHints outp } @Override - public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Unsupported."); } } diff --git a/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java b/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java index 1d601987211b7..6b7c928ee2e46 100644 --- a/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java +++ b/examples/java8/src/main/java/org/apache/beam/examples/complete/game/utils/WriteToText.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.joda.time.DateTimeZone; @@ -143,20 +144,25 @@ public String filenamePrefixForWindow(IntervalWindow window) { } @Override - public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { - IntervalWindow 
window = (IntervalWindow) context.getWindow(); + public ResourceId windowedFilename(int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { + IntervalWindow intervalWindow = (IntervalWindow) window; String filename = String.format( "%s-%s-of-%s%s", - filenamePrefixForWindow(window), - context.getShardNumber(), - context.getNumShards(), + filenamePrefixForWindow(intervalWindow), + shardNumber, + numShards, outputFileHints.getSuggestedFilenameSuffix()); return prefix.getCurrentDirectory().resolve(filename, StandardResolveOptions.RESOLVE_FILE); } @Override - public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Unsupported."); } } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java index 283df1657dedb..4259ac893b64c 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java @@ -40,6 +40,8 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SerializableFunctions; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.junit.Test; @@ -163,13 +165,19 @@ public FileBasedSink.Writer createWriter() throws Exception { private static class DummyFilenamePolicy extends FilenamePolicy { @Override - public ResourceId windowedFilename(WindowedContext c, OutputFileHints outputFileHints) { + public ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Should never be called."); } @Nullable @Override - public ResourceId unwindowedFilename(Context c, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Should never be called."); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java index 7a60e49ebfb03..64d7edc45f007 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java @@ -52,19 +52,19 @@ * with the number of shards, index of the particular file, current window and pane information, * using {@link #constructName}. * - *

    Most users will use this {@link DefaultFilenamePolicy}. For more advanced - * uses in generating different files for each window and other sharding controls, see the - * {@code WriteOneFilePerWindow} example pipeline. + *

    Most users will use this {@link DefaultFilenamePolicy}. For more advanced uses in generating + * different files for each window and other sharding controls, see the {@code + * WriteOneFilePerWindow} example pipeline. */ public final class DefaultFilenamePolicy extends FilenamePolicy { /** The default sharding name template. */ public static final String DEFAULT_UNWINDOWED_SHARD_TEMPLATE = ShardNameTemplate.INDEX_OF_MAX; - /** The default windowed sharding name template used when writing windowed files. - * This is used as default in cases when user did not specify shard template to - * be used and there is a need to write windowed files. In cases when user does - * specify shard template to be used then provided template will be used for both - * windowed and non-windowed file names. + /** + * The default windowed sharding name template used when writing windowed files. This is used as + * default in cases when user did not specify shard template to be used and there is a need to + * write windowed files. In cases when user does specify shard template to be used then provided + * template will be used for both windowed and non-windowed file names. */ private static final String DEFAULT_WINDOWED_SHARD_TEMPLATE = "W-P" + DEFAULT_UNWINDOWED_SHARD_TEMPLATE; @@ -190,11 +190,11 @@ public Params decode(InputStream inStream) throws IOException { *

    This is a shortcut for: * *

    {@code
    -   *   DefaultFilenamePolicy.fromParams(new Params()
    -   *     .withBaseFilename(baseFilename)
    -   *     .withShardTemplate(shardTemplate)
    -   *     .withSuffix(filenameSuffix)
    -   *     .withWindowedWrites())
    +   * DefaultFilenamePolicy.fromParams(new Params()
    +   *   .withBaseFilename(baseFilename)
    +   *   .withShardTemplate(shardTemplate)
    +   *   .withSuffix(filenameSuffix)
    +   *   .withWindowedWrites())
        * }
    * *

    Where the respective {@code with} methods are invoked only if the value is non-null or true. @@ -284,28 +284,33 @@ static ResourceId constructName( @Override @Nullable - public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { return constructName( params.baseFilename.get(), params.shardTemplate, params.suffix + outputFileHints.getSuggestedFilenameSuffix(), - context.getShardNumber(), - context.getNumShards(), + shardNumber, + numShards, null, null); } @Override - public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { - final PaneInfo paneInfo = context.getPaneInfo(); + public ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { String paneStr = paneInfoToString(paneInfo); - String windowStr = windowToString(context.getWindow()); + String windowStr = windowToString(window); return constructName( params.baseFilename.get(), params.shardTemplate, params.suffix + outputFileHints.getSuggestedFilenameSuffix(), - context.getShardNumber(), - context.getNumShards(), + shardNumber, + numShards, paneStr, windowStr); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java index 583af60df68b3..c68b79438f0b8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java @@ -58,8 +58,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy.Context; -import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy.WindowedContext; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; import org.apache.beam.sdk.io.fs.MoveOptions.StandardMoveOptions; @@ -96,9 +94,9 @@ *

    The process of writing to file-based sink is as follows: * *

      - *
    1. An optional subclass-defined initialization, - *
    2. a parallel write of bundles to temporary files, and finally, - *
    3. these temporary files are renamed with final output filenames. + *
    1. An optional subclass-defined initialization, + *
    2. a parallel write of bundles to temporary files, and finally, + *
    3. these temporary files are renamed with final output filenames. *
    * *

    In order to ensure fault-tolerance, a bundle may be executed multiple times (e.g., in the @@ -125,46 +123,36 @@ public abstract class FileBasedSink implements Serializable, HasDisplayData { private static final Logger LOG = LoggerFactory.getLogger(FileBasedSink.class); - /** - * Directly supported file output compression types. - */ + /** Directly supported file output compression types. */ public enum CompressionType implements WritableByteChannelFactory { - /** - * No compression, or any other transformation, will be used. - */ + /** No compression, or any other transformation, will be used. */ UNCOMPRESSED("", null) { @Override public WritableByteChannel create(WritableByteChannel channel) throws IOException { return channel; } }, - /** - * Provides GZip output transformation. - */ + /** Provides GZip output transformation. */ GZIP(".gz", MimeTypes.BINARY) { @Override public WritableByteChannel create(WritableByteChannel channel) throws IOException { return Channels.newChannel(new GZIPOutputStream(Channels.newOutputStream(channel), true)); } }, - /** - * Provides BZip2 output transformation. - */ + /** Provides BZip2 output transformation. */ BZIP2(".bz2", MimeTypes.BINARY) { @Override public WritableByteChannel create(WritableByteChannel channel) throws IOException { - return Channels - .newChannel(new BZip2CompressorOutputStream(Channels.newOutputStream(channel))); + return Channels.newChannel( + new BZip2CompressorOutputStream(Channels.newOutputStream(channel))); } }, - /** - * Provides deflate output transformation. - */ + /** Provides deflate output transformation. */ DEFLATE(".deflate", MimeTypes.BINARY) { @Override public WritableByteChannel create(WritableByteChannel channel) throws IOException { - return Channels - .newChannel(new DeflateCompressorOutputStream(Channels.newOutputStream(channel))); + return Channels.newChannel( + new DeflateCompressorOutputStream(Channels.newOutputStream(channel))); } }; @@ -182,7 +170,8 @@ public String getSuggestedFilenameSuffix() { } @Override - @Nullable public String getMimeType() { + @Nullable + public String getMimeType() { return mimeType; } } @@ -213,8 +202,8 @@ public static ResourceId convertToFileResourceIfPossible(String outputPrefix) { /** * The {@link WritableByteChannelFactory} that is used to wrap the raw data output to the - * underlying channel. The default is to not compress the output using - * {@link CompressionType#UNCOMPRESSED}. + * underlying channel. The default is to not compress the output using {@link + * CompressionType#UNCOMPRESSED}. */ private final WritableByteChannelFactory writableByteChannelFactory; @@ -284,86 +273,21 @@ final Coder getDestinationCoderWithDefault(CoderRegistry registry) /** A naming policy for output files. */ @Experimental(Kind.FILESYSTEM) public abstract static class FilenamePolicy implements Serializable { - /** - * Context used for generating a name based on shard number, and num shards. - * The policy must produce unique filenames for unique {@link Context} objects. - * - *

    Be careful about adding fields to this as existing strategies will not notice the new - * fields, and may not produce unique filenames. - */ - public static class Context { - private int shardNumber; - private int numShards; - - - public Context(int shardNumber, int numShards) { - this.shardNumber = shardNumber; - this.numShards = numShards; - } - - public int getShardNumber() { - return shardNumber; - } - - - public int getNumShards() { - return numShards; - } - } - - /** - * Context used for generating a name based on window, pane, shard number, and num shards. - * The policy must produce unique filenames for unique {@link WindowedContext} objects. - * - *

    Be careful about adding fields to this as existing strategies will not notice the new - * fields, and may not produce unique filenames. - */ - public static class WindowedContext { - private int shardNumber; - private int numShards; - private BoundedWindow window; - private PaneInfo paneInfo; - - public WindowedContext( - BoundedWindow window, - PaneInfo paneInfo, - int shardNumber, - int numShards) { - this.window = window; - this.paneInfo = paneInfo; - this.shardNumber = shardNumber; - this.numShards = numShards; - } - - public BoundedWindow getWindow() { - return window; - } - - public PaneInfo getPaneInfo() { - return paneInfo; - } - - public int getShardNumber() { - return shardNumber; - } - - public int getNumShards() { - return numShards; - } - } - /** * When a sink has requested windowed or triggered output, this method will be invoked to return * the file {@link ResourceId resource} to be created given the base output directory and a * {@link OutputFileHints} containing information about the file, including a suggested * extension (e.g. coming from {@link CompressionType}). * - *

    The {@link WindowedContext} object gives access to the window and pane, as well as - * sharding information. The policy must return unique and consistent filenames for different - * windows and panes. + *

    The policy must return unique and consistent filenames for different windows and panes. */ @Experimental(Kind.FILESYSTEM) - public abstract ResourceId windowedFilename(WindowedContext c, OutputFileHints outputFileHints); + public abstract ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints); /** * When a sink has not requested windowed or triggered output, this method will be invoked to @@ -371,18 +295,16 @@ public int getNumShards() { * a {@link OutputFileHints} containing information about the file, including a suggested (e.g. * coming from {@link CompressionType}). * - *

    The {@link Context} object only provides sharding information, which is used by the policy - * to generate unique and consistent filenames. + *

    The shardNumber and numShards parameters should be used by the policy to generate unique + * and consistent filenames. */ @Experimental(Kind.FILESYSTEM) @Nullable - public abstract ResourceId unwindowedFilename(Context c, OutputFileHints outputFileHints); + public abstract ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints); - /** - * Populates the display data. - */ - public void populateDisplayData(DisplayData.Builder builder) { - } + /** Populates the display data. */ + public void populateDisplayData(DisplayData.Builder builder) {} } /** The directory to which files will be written. */ @@ -449,11 +371,11 @@ public void populateDisplayData(DisplayData.Builder builder) { * written, * *

      - *
    1. {@link WriteOperation#finalize} is given a list of the temporary files containing the - * output bundles. - *
    2. During finalize, these temporary files are copied to final output locations and named - * according to a file naming template. - *
    3. Finally, any temporary files that were created during the write are removed. + *
    1. {@link WriteOperation#finalize} is given a list of the temporary files containing the + * output bundles. + *
    2. During finalize, these temporary files are copied to final output locations and named + * according to a file naming template. + *
    3. Finally, any temporary files that were created during the write are removed. *
    * *

    Subclass implementations of WriteOperation must implement {@link @@ -558,9 +480,7 @@ private WriteOperation( */ public abstract Writer createWriter() throws Exception; - /** - * Indicates that the operation will be performing windowed writes. - */ + /** Indicates that the operation will be performing windowed writes. */ public void setWindowedWrites(boolean windowedWrites) { this.windowedWrites = windowedWrites; } @@ -659,9 +579,11 @@ public int compare( } int numDistinctShards = new HashSet<>(outputFilenames.values()).size(); - checkState(numDistinctShards == outputFilenames.size(), - "Only generated %s distinct file names for %s files.", - numDistinctShards, outputFilenames.size()); + checkState( + numDistinctShards == outputFilenames.size(), + "Only generated %s distinct file names for %s files.", + numDistinctShards, + outputFilenames.size()); return outputFilenames; } @@ -726,8 +648,9 @@ final void removeTemporaryFiles( // ignore the exception for now to avoid failing the pipeline. if (shouldRemoveTemporaryDirectory) { try { - MatchResult singleMatch = Iterables.getOnlyElement( - FileSystems.match(Collections.singletonList(tempDir.toString() + "*"))); + MatchResult singleMatch = + Iterables.getOnlyElement( + FileSystems.match(Collections.singletonList(tempDir.toString() + "*"))); for (Metadata matchResult : singleMatch.metadata()) { matches.add(matchResult.resourceId()); } @@ -807,18 +730,16 @@ public abstract static class Writer { /** The output file for this bundle. May be null if opening failed. */ private @Nullable ResourceId outputFile; - /** - * The channel to write to. - */ + /** The channel to write to. */ private WritableByteChannel channel; /** * The MIME type used in the creation of the output channel (if the file system supports it). * - *

    This is the default for the sink, but it may be overridden by a supplied - * {@link WritableByteChannelFactory}. For example, {@link TextIO.Write} uses - * {@link MimeTypes#TEXT} by default but if {@link CompressionType#BZIP2} is set then - * the MIME type will be overridden to {@link MimeTypes#BINARY}. + *

    This is the default for the sink, but it may be overridden by a supplied {@link + * WritableByteChannelFactory}. For example, {@link TextIO.Write} uses {@link MimeTypes#TEXT} by + * default but if {@link CompressionType#BZIP2} is set then the MIME type will be overridden to + * {@link MimeTypes#BINARY}. */ private final String mimeType; @@ -843,14 +764,12 @@ public Writer(WriteOperation writeOperation, String mimeT */ protected void writeHeader() throws Exception {} - /** - * Writes footer at the end of output files. Nothing by default; subclasses may override. - */ + /** Writes footer at the end of output files. Nothing by default; subclasses may override. */ protected void writeFooter() throws Exception {} /** - * Called after all calls to {@link #writeHeader}, {@link #write} and {@link #writeFooter}. - * If any resources opened in the write processes need to be flushed, flush them here. + * Called after all calls to {@link #writeHeader}, {@link #write} and {@link #writeFooter}. If + * any resources opened in the write processes need to be flushed, flush them here. */ protected void finishWrite() throws Exception {} @@ -875,9 +794,7 @@ public final void openWindowed( open(uId, window, paneInfo, shard, destination); } - /** - * Called for each value in the bundle. - */ + /** Called for each value in the bundle. */ public abstract void write(OutputT value) throws Exception; /** @@ -982,7 +899,9 @@ public final FileResult close() throws Exception { checkState( channel.isOpen(), - "Channel %s to %s should only be closed by its owner: %s", channel, outputFile); + "Channel %s to %s should only be closed by its owner: %s", + channel, + outputFile); LOG.debug("Closing channel to {}.", outputFile); try { @@ -1063,10 +982,9 @@ public ResourceId getDestinationFile( FilenamePolicy policy = dynamicDestinations.getFilenamePolicy(destination); if (getWindow() != null) { return policy.windowedFilename( - new WindowedContext(getWindow(), getPaneInfo(), getShard(), numShards), - outputFileHints); + getShard(), numShards, getWindow(), getPaneInfo(), outputFileHints); } else { - return policy.unwindowedFilename(new Context(getShard(), numShards), outputFileHints); + return policy.unwindowedFilename(getShard(), numShards, outputFileHints); } } @@ -1154,7 +1072,7 @@ public interface OutputFileHints extends Serializable { * * @see MimeTypes * @see http://www.iana.org/assignments/media-types/media-types.xhtml + * 'http://www.iana.org/assignments/media-types/media-types.xhtml'>http://www.iana.org/assignments/media-types/media-types.xhtml */ @Nullable String getMimeType(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java index 260e47a25a5cb..4a1386cbf575b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java @@ -68,8 +68,10 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PCollection; @@ -84,20 +86,15 
@@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Tests for AvroIO Read and Write transforms. - */ +/** Tests for AvroIO Read and Write transforms. */ @RunWith(JUnit4.class) public class AvroIOTest { - @Rule - public TestPipeline p = TestPipeline.create(); + @Rule public TestPipeline p = TestPipeline.create(); - @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); - @Rule - public ExpectedException expectedException = ExpectedException.none(); + @Rule public ExpectedException expectedException = ExpectedException.none(); @Test public void testAvroIOGetName() { @@ -109,11 +106,14 @@ public void testAvroIOGetName() { static class GenericClass { int intField; String stringField; + public GenericClass() {} + public GenericClass(int intValue, String stringValue) { this.intField = intValue; this.stringField = stringValue; } + @Override public String toString() { return MoreObjects.toStringHelper(getClass()) @@ -121,10 +121,12 @@ public String toString() { .add("stringField", stringField) .toString(); } + @Override public int hashCode() { return Objects.hash(intField, stringField); } + @Override public boolean equals(Object other) { if (other == null || !(other instanceof GenericClass)) { @@ -138,20 +140,16 @@ public boolean equals(Object other) { @Test @Category(NeedsRunner.class) public void testAvroIOWriteAndReadASingleFile() throws Throwable { - List values = ImmutableList.of(new GenericClass(3, "hi"), - new GenericClass(5, "bar")); + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); File outputFile = tmpFolder.newFile("output.avro"); p.apply(Create.of(values)) - .apply(AvroIO.write(GenericClass.class) - .to(outputFile.getAbsolutePath()) - .withoutSharding()); + .apply(AvroIO.write(GenericClass.class).to(outputFile.getAbsolutePath()).withoutSharding()); p.run(); PCollection input = - p.apply( - AvroIO.read(GenericClass.class) - .from(outputFile.getAbsolutePath())); + p.apply(AvroIO.read(GenericClass.class).from(outputFile.getAbsolutePath())); PAssert.that(input).containsInAnyOrder(values); p.run(); @@ -161,25 +159,25 @@ public void testAvroIOWriteAndReadASingleFile() throws Throwable { @SuppressWarnings("unchecked") @Category(NeedsRunner.class) public void testAvroIOCompressedWriteAndReadASingleFile() throws Throwable { - List values = ImmutableList.of(new GenericClass(3, "hi"), - new GenericClass(5, "bar")); + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); File outputFile = tmpFolder.newFile("output.avro"); p.apply(Create.of(values)) - .apply(AvroIO.write(GenericClass.class) - .to(outputFile.getAbsolutePath()) - .withoutSharding() - .withCodec(CodecFactory.deflateCodec(9))); + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withCodec(CodecFactory.deflateCodec(9))); p.run(); - PCollection input = p - .apply(AvroIO.read(GenericClass.class) - .from(outputFile.getAbsolutePath())); + PCollection input = + p.apply(AvroIO.read(GenericClass.class).from(outputFile.getAbsolutePath())); PAssert.that(input).containsInAnyOrder(values); p.run(); - DataFileStream dataFileStream = new DataFileStream(new FileInputStream(outputFile), - new GenericDatumReader()); + DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader()); assertEquals("deflate", dataFileStream.getMetaString("avro.codec")); } @@ -187,25 +185,25 @@ 
public void testAvroIOCompressedWriteAndReadASingleFile() throws Throwable { @SuppressWarnings("unchecked") @Category(NeedsRunner.class) public void testAvroIONullCodecWriteAndReadASingleFile() throws Throwable { - List values = ImmutableList.of(new GenericClass(3, "hi"), - new GenericClass(5, "bar")); + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); File outputFile = tmpFolder.newFile("output.avro"); p.apply(Create.of(values)) - .apply(AvroIO.write(GenericClass.class) - .to(outputFile.getAbsolutePath()) - .withoutSharding() - .withCodec(CodecFactory.nullCodec())); + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withCodec(CodecFactory.nullCodec())); p.run(); - PCollection input = p - .apply(AvroIO.read(GenericClass.class) - .from(outputFile.getAbsolutePath())); + PCollection input = + p.apply(AvroIO.read(GenericClass.class).from(outputFile.getAbsolutePath())); PAssert.that(input).containsInAnyOrder(values); p.run(); - DataFileStream dataFileStream = new DataFileStream(new FileInputStream(outputFile), - new GenericDatumReader()); + DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader()); assertEquals("null", dataFileStream.getMetaString("avro.codec")); } @@ -214,12 +212,15 @@ static class GenericClassV2 { int intField; String stringField; @Nullable String nullableField; + public GenericClassV2() {} + public GenericClassV2(int intValue, String stringValue, String nullableValue) { this.intField = intValue; this.stringField = stringValue; this.nullableField = nullableValue; } + @Override public String toString() { return MoreObjects.toStringHelper(getClass()) @@ -228,10 +229,12 @@ public String toString() { .add("nullableField", nullableField) .toString(); } + @Override public int hashCode() { return Objects.hash(intField, stringField, nullableField); } + @Override public boolean equals(Object other) { if (other == null || !(other instanceof GenericClassV2)) { @@ -245,32 +248,28 @@ public boolean equals(Object other) { } /** - * Tests that {@code AvroIO} can read an upgraded version of an old class, as long as the - * schema resolution process succeeds. This test covers the case when a new, {@code @Nullable} - * field has been added. + * Tests that {@code AvroIO} can read an upgraded version of an old class, as long as the schema + * resolution process succeeds. This test covers the case when a new, {@code @Nullable} field has + * been added. * *

    For more information, see http://avro.apache.org/docs/1.7.7/spec.html#Schema+Resolution */ @Test @Category(NeedsRunner.class) public void testAvroIOWriteAndReadSchemaUpgrade() throws Throwable { - List values = ImmutableList.of(new GenericClass(3, "hi"), - new GenericClass(5, "bar")); + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); File outputFile = tmpFolder.newFile("output.avro"); p.apply(Create.of(values)) - .apply(AvroIO.write(GenericClass.class) - .to(outputFile.getAbsolutePath()) - .withoutSharding()); + .apply(AvroIO.write(GenericClass.class).to(outputFile.getAbsolutePath()).withoutSharding()); p.run(); - List expected = ImmutableList.of(new GenericClassV2(3, "hi", null), - new GenericClassV2(5, "bar", null)); + List expected = + ImmutableList.of(new GenericClassV2(3, "hi", null), new GenericClassV2(5, "bar", null)); PCollection input = - p.apply( - AvroIO.read(GenericClassV2.class) - .from(outputFile.getAbsolutePath())); + p.apply(AvroIO.read(GenericClassV2.class).from(outputFile.getAbsolutePath())); PAssert.that(input).containsInAnyOrder(expected); p.run(); @@ -284,7 +283,12 @@ private static class WindowedFilenamePolicy extends FilenamePolicy { } @Override - public ResourceId windowedFilename(WindowedContext input, OutputFileHints outputFileHints) { + public ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { String filenamePrefix = outputFilePrefix.isDirectory() ? "" : firstNonNull(outputFilePrefix.getFilename(), ""); @@ -292,11 +296,11 @@ public ResourceId windowedFilename(WindowedContext input, OutputFileHints output String.format( "%s-%s-%s-of-%s-pane-%s%s%s", filenamePrefix, - input.getWindow(), - input.getShardNumber(), - input.getNumShards() - 1, - input.getPaneInfo().getIndex(), - input.getPaneInfo().isLast() ? "-final" : "", + window, + shardNumber, + numShards - 1, + paneInfo.getIndex(), + paneInfo.isLast() ? 
"-final" : "", outputFileHints.getSuggestedFilenameSuffix()); return outputFilePrefix .getCurrentDirectory() @@ -304,7 +308,8 @@ public ResourceId windowedFilename(WindowedContext input, OutputFileHints output } @Override - public ResourceId unwindowedFilename(Context input, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { throw new UnsupportedOperationException("Expecting windowed outputs only"); } @@ -316,8 +321,7 @@ public void populateDisplayData(DisplayData.Builder builder) { } } - @Rule - public TestPipeline windowedAvroWritePipeline = TestPipeline.create(); + @Rule public TestPipeline windowedAvroWritePipeline = TestPipeline.create(); @Test @Category({ValidatesRunner.class, UsesTestStream.class}) @@ -328,27 +332,31 @@ public void testWindowedAvroIOWrite() throws Throwable { Instant base = new Instant(0); ArrayList allElements = new ArrayList<>(); ArrayList> firstWindowElements = new ArrayList<>(); - ArrayList firstWindowTimestamps = Lists.newArrayList( - base.plus(Duration.standardSeconds(0)), base.plus(Duration.standardSeconds(10)), - base.plus(Duration.standardSeconds(20)), base.plus(Duration.standardSeconds(30))); + ArrayList firstWindowTimestamps = + Lists.newArrayList( + base.plus(Duration.standardSeconds(0)), base.plus(Duration.standardSeconds(10)), + base.plus(Duration.standardSeconds(20)), base.plus(Duration.standardSeconds(30))); Random random = new Random(); for (int i = 0; i < 100; ++i) { GenericClass item = new GenericClass(i, String.valueOf(i)); allElements.add(item); - firstWindowElements.add(TimestampedValue.of(item, - firstWindowTimestamps.get(random.nextInt(firstWindowTimestamps.size())))); + firstWindowElements.add( + TimestampedValue.of( + item, firstWindowTimestamps.get(random.nextInt(firstWindowTimestamps.size())))); } ArrayList> secondWindowElements = new ArrayList<>(); - ArrayList secondWindowTimestamps = Lists.newArrayList( - base.plus(Duration.standardSeconds(60)), base.plus(Duration.standardSeconds(70)), - base.plus(Duration.standardSeconds(80)), base.plus(Duration.standardSeconds(90))); + ArrayList secondWindowTimestamps = + Lists.newArrayList( + base.plus(Duration.standardSeconds(60)), base.plus(Duration.standardSeconds(70)), + base.plus(Duration.standardSeconds(80)), base.plus(Duration.standardSeconds(90))); for (int i = 100; i < 200; ++i) { GenericClass item = new GenericClass(i, String.valueOf(i)); allElements.add(new GenericClass(i, String.valueOf(i))); - secondWindowElements.add(TimestampedValue.of(item, - secondWindowTimestamps.get(random.nextInt(secondWindowTimestamps.size())))); + secondWindowElements.add( + TimestampedValue.of( + item, secondWindowTimestamps.get(random.nextInt(secondWindowTimestamps.size())))); } TimestampedValue[] firstWindowArray = @@ -356,14 +364,17 @@ public void testWindowedAvroIOWrite() throws Throwable { TimestampedValue[] secondWindowArray = secondWindowElements.toArray(new TimestampedValue[100]); - TestStream values = TestStream.create(AvroCoder.of(GenericClass.class)) - .advanceWatermarkTo(new Instant(0)) - .addElements(firstWindowArray[0], - Arrays.copyOfRange(firstWindowArray, 1, firstWindowArray.length)) - .advanceWatermarkTo(new Instant(0).plus(Duration.standardMinutes(1))) - .addElements(secondWindowArray[0], - Arrays.copyOfRange(secondWindowArray, 1, secondWindowArray.length)) - .advanceWatermarkToInfinity(); + TestStream values = + TestStream.create(AvroCoder.of(GenericClass.class)) + .advanceWatermarkTo(new Instant(0)) 
+ .addElements( + firstWindowArray[0], + Arrays.copyOfRange(firstWindowArray, 1, firstWindowArray.length)) + .advanceWatermarkTo(new Instant(0).plus(Duration.standardMinutes(1))) + .addElements( + secondWindowArray[0], + Arrays.copyOfRange(secondWindowArray, 1, secondWindowArray.length)) + .advanceWatermarkToInfinity(); FilenamePolicy policy = new WindowedFilenamePolicy(FileBasedSink.convertToFileResourceIfPossible(baseFilename)); @@ -384,11 +395,17 @@ public void testWindowedAvroIOWrite() throws Throwable { for (int shard = 0; shard < 2; shard++) { for (int window = 0; window < 2; window++) { Instant windowStart = new Instant(0).plus(Duration.standardMinutes(window)); - IntervalWindow intervalWindow = new IntervalWindow( - windowStart, Duration.standardMinutes(1)); + IntervalWindow intervalWindow = + new IntervalWindow(windowStart, Duration.standardMinutes(1)); expectedFiles.add( - new File(baseFilename + "-" + intervalWindow.toString() + "-" + shard - + "-of-1" + "-pane-0-final")); + new File( + baseFilename + + "-" + + intervalWindow.toString() + + "-" + + shard + + "-of-1" + + "-pane-0-final")); } } @@ -396,9 +413,10 @@ public void testWindowedAvroIOWrite() throws Throwable { for (File outputFile : expectedFiles) { assertTrue("Expected output file " + outputFile.getAbsolutePath(), outputFile.exists()); try (DataFileReader reader = - new DataFileReader<>(outputFile, - new ReflectDatumReader( - ReflectData.get().getSchema(GenericClass.class)))) { + new DataFileReader<>( + outputFile, + new ReflectDatumReader( + ReflectData.get().getSchema(GenericClass.class)))) { Iterators.addAll(actualElements, reader); } outputFile.delete(); @@ -408,25 +426,22 @@ public void testWindowedAvroIOWrite() throws Throwable { @Test public void testWriteWithDefaultCodec() throws Exception { - AvroIO.Write write = AvroIO.write(String.class) - .to("/tmp/foo/baz"); + AvroIO.Write write = AvroIO.write(String.class).to("/tmp/foo/baz"); assertEquals(CodecFactory.deflateCodec(6).toString(), write.getCodec().toString()); } @Test public void testWriteWithCustomCodec() throws Exception { - AvroIO.Write write = AvroIO.write(String.class) - .to("/tmp/foo/baz") - .withCodec(CodecFactory.snappyCodec()); + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.snappyCodec()); assertEquals(SNAPPY_CODEC, write.getCodec().toString()); } @Test @SuppressWarnings("unchecked") public void testWriteWithSerDeCustomDeflateCodec() throws Exception { - AvroIO.Write write = AvroIO.write(String.class) - .to("/tmp/foo/baz") - .withCodec(CodecFactory.deflateCodec(9)); + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.deflateCodec(9)); assertEquals( CodecFactory.deflateCodec(9).toString(), @@ -436,9 +451,8 @@ public void testWriteWithSerDeCustomDeflateCodec() throws Exception { @Test @SuppressWarnings("unchecked") public void testWriteWithSerDeCustomXZCodec() throws Exception { - AvroIO.Write write = AvroIO.write(String.class) - .to("/tmp/foo/baz") - .withCodec(CodecFactory.xzCodec(9)); + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.xzCodec(9)); assertEquals( CodecFactory.xzCodec(9).toString(), @@ -449,28 +463,32 @@ public void testWriteWithSerDeCustomXZCodec() throws Exception { @SuppressWarnings("unchecked") @Category(NeedsRunner.class) public void testMetadata() throws Exception { - List values = ImmutableList.of(new GenericClass(3, "hi"), - new GenericClass(5, "bar")); + List values = + ImmutableList.of(new 
GenericClass(3, "hi"), new GenericClass(5, "bar")); File outputFile = tmpFolder.newFile("output.avro"); p.apply(Create.of(values)) - .apply(AvroIO.write(GenericClass.class) - .to(outputFile.getAbsolutePath()) - .withoutSharding() - .withMetadata(ImmutableMap.of( - "stringKey", "stringValue", - "longKey", 100L, - "bytesKey", "bytesValue".getBytes()))); + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withMetadata( + ImmutableMap.of( + "stringKey", + "stringValue", + "longKey", + 100L, + "bytesKey", + "bytesValue".getBytes()))); p.run(); - DataFileStream dataFileStream = new DataFileStream(new FileInputStream(outputFile), - new GenericDatumReader()); + DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader()); assertEquals("stringValue", dataFileStream.getMetaString("stringKey")); assertEquals(100L, dataFileStream.getMetaLong("longKey")); assertArrayEquals("bytesValue".getBytes(), dataFileStream.getMeta("bytesKey")); } - @SuppressWarnings("deprecation") // using AvroCoder#createDatumReader for tests. private void runTestWrite(String[] expectedElements, int numShards) throws IOException { File baseOutputFile = new File(tmpFolder.getRoot(), "prefix"); @@ -488,8 +506,8 @@ private void runTestWrite(String[] expectedElements, int numShards) throws IOExc p.run(); String shardNameTemplate = - firstNonNull(write.getShardTemplate(), - DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + firstNonNull( + write.getShardTemplate(), DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); assertTestOutputs(expectedElements, numShards, outputFilePrefix, shardNameTemplate); } @@ -517,8 +535,8 @@ public static void assertTestOutputs( for (File outputFile : expectedFiles) { assertTrue("Expected output file " + outputFile.getName(), outputFile.exists()); try (DataFileReader reader = - new DataFileReader<>(outputFile, - new ReflectDatumReader(ReflectData.get().getSchema(String.class)))) { + new DataFileReader<>( + outputFile, new ReflectDatumReader(ReflectData.get().getSchema(String.class)))) { Iterators.addAll(actualElements, reader); } } @@ -560,18 +578,21 @@ public void testPrimitiveReadDisplayData() { AvroIO.readGenericRecords(Schema.create(Schema.Type.STRING)).from("/foo.*"); Set displayData = evaluator.displayDataForPrimitiveSourceTransforms(read); - assertThat("AvroIO.Read should include the file pattern in its primitive transform", - displayData, hasItem(hasDisplayItem("filePattern"))); + assertThat( + "AvroIO.Read should include the file pattern in its primitive transform", + displayData, + hasItem(hasDisplayItem("filePattern"))); } @Test public void testWriteDisplayData() { - AvroIO.Write write = AvroIO.write(GenericClass.class) - .to("/foo") - .withShardNameTemplate("-SS-of-NN-") - .withSuffix("bar") - .withNumShards(100) - .withCodec(CodecFactory.snappyCodec()); + AvroIO.Write write = + AvroIO.write(GenericClass.class) + .to("/foo") + .withShardNameTemplate("-SS-of-NN-") + .withSuffix("bar") + .withNumShards(100) + .withCodec(CodecFactory.snappyCodec()); DisplayData displayData = DisplayData.from(write); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java index 755bb598524d6..b7567785880de 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java @@ -48,7 +48,6 @@ 
import org.apache.beam.sdk.io.FileBasedSink.CompressionType; import org.apache.beam.sdk.io.FileBasedSink.FileResult; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; -import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy.Context; import org.apache.beam.sdk.io.FileBasedSink.WritableByteChannelFactory; import org.apache.beam.sdk.io.FileBasedSink.WriteOperation; import org.apache.beam.sdk.io.FileBasedSink.Writer; @@ -62,9 +61,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Tests for {@link FileBasedSink}. - */ +/** Tests for {@link FileBasedSink}. */ @RunWith(JUnit4.class) public class FileBasedSinkTest { @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); @@ -87,14 +84,14 @@ private ResourceId getBaseTempDirectory() { } /** - * Writer opens the correct file, writes the header, footer, and elements in the correct - * order, and returns the correct filename. + * Writer opens the correct file, writes the header, footer, and elements in the correct order, + * and returns the correct filename. */ @Test public void testWriter() throws Exception { String testUid = "testId"; - ResourceId expectedTempFile = getBaseTempDirectory() - .resolve(testUid, StandardResolveOptions.RESOLVE_FILE); + ResourceId expectedTempFile = + getBaseTempDirectory().resolve(testUid, StandardResolveOptions.RESOLVE_FILE); List values = Arrays.asList("sympathetic vulture", "boresome hummingbird"); List expected = new ArrayList<>(); expected.add(SimpleSink.SimpleWriter.HEADER); @@ -114,9 +111,7 @@ public void testWriter() throws Exception { assertFileContains(expected, expectedTempFile); } - /** - * Assert that a file contains the lines provided, in the same order as expected. - */ + /** Assert that a file contains the lines provided, in the same order as expected. */ private void assertFileContains(List expected, ResourceId file) throws Exception { try (BufferedReader reader = new BufferedReader(new FileReader(file.toString()))) { List actual = new ArrayList<>(); @@ -140,9 +135,7 @@ private void writeFile(List lines, File file) throws Exception { } } - /** - * Removes temporary files when temporary and output directories differ. - */ + /** Removes temporary files when temporary and output directories differ. */ @Test public void testRemoveWithTempFilename() throws Exception { testRemoveTemporaryFiles(3, getBaseTempDirectory()); @@ -218,7 +211,7 @@ private void runFinalize(SimpleSink.SimpleWriteOperation writeOp, List tem .getSink() .getDynamicDestinations() .getFilenamePolicy(null) - .unwindowedFilename(new Context(i, numFiles), CompressionType.UNCOMPRESSED); + .unwindowedFilename(i, numFiles, CompressionType.UNCOMPRESSED); assertTrue(new File(outputFilename.toString()).exists()); assertFalse(temporaryFiles.get(i).exists()); } @@ -232,8 +225,7 @@ private void runFinalize(SimpleSink.SimpleWriteOperation writeOp, List tem * Create n temporary and output files and verify that removeTemporaryFiles only removes temporary * files. 
*/ - private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) - throws Exception { + private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) throws Exception { String prefix = "file"; SimpleSink sink = SimpleSink.makeSimpleSink( @@ -245,8 +237,7 @@ private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) List temporaryFiles = new ArrayList<>(); List outputFiles = new ArrayList<>(); for (int i = 0; i < numFiles; i++) { - ResourceId tempResource = - WriteOperation.buildTemporaryFilename(tempDirectory, prefix + i); + ResourceId tempResource = WriteOperation.buildTemporaryFilename(tempDirectory, prefix + i); File tmpFile = new File(tempResource.toString()); tmpFile.getParentFile().mkdirs(); assertTrue("not able to create new temp file", tmpFile.createNewFile()); @@ -264,12 +255,9 @@ private void testRemoveTemporaryFiles(int numFiles, ResourceId tempDirectory) for (int i = 0; i < numFiles; i++) { File temporaryFile = temporaryFiles.get(i); assertThat( - String.format("temp file %s exists", temporaryFile), - temporaryFile.exists(), is(false)); + String.format("temp file %s exists", temporaryFile), temporaryFile.exists(), is(false)); File outputFile = outputFiles.get(i); - assertThat( - String.format("output file %s exists", outputFile), - outputFile.exists(), is(true)); + assertThat(String.format("output file %s exists", outputFile), outputFile.exists(), is(true)); } } @@ -279,8 +267,8 @@ public void testCopyToOutputFiles() throws Exception { SimpleSink.SimpleWriteOperation writeOp = buildWriteOperation(); List inputFilenames = Arrays.asList("input-1", "input-2", "input-3"); List inputContents = Arrays.asList("1", "2", "3"); - List expectedOutputFilenames = Arrays.asList( - "file-00-of-03.test", "file-01-of-03.test", "file-02-of-03.test"); + List expectedOutputFilenames = + Arrays.asList("file-00-of-03.test", "file-01-of-03.test", "file-02-of-03.test"); Map inputFilePaths = new HashMap<>(); List expectedOutputPaths = new ArrayList<>(); @@ -301,8 +289,7 @@ public void testCopyToOutputFiles() throws Exception { .getSink() .getDynamicDestinations() .getFilenamePolicy(null) - .unwindowedFilename( - new Context(i, inputFilenames.size()), CompressionType.UNCOMPRESSED)); + .unwindowedFilename(i, inputFilenames.size(), CompressionType.UNCOMPRESSED)); } // Copy input files to output files. @@ -319,16 +306,12 @@ public List generateDestinationFilenames( ResourceId outputDirectory, FilenamePolicy policy, int numFiles) { List filenames = new ArrayList<>(); for (int i = 0; i < numFiles; i++) { - filenames.add( - policy.unwindowedFilename(new Context(i, numFiles), CompressionType.UNCOMPRESSED)); + filenames.add(policy.unwindowedFilename(i, numFiles, CompressionType.UNCOMPRESSED)); } return filenames; } - /** - * Output filenames are generated correctly when an extension is supplied. - */ - + /** Output filenames are generated correctly when an extension is supplied. 
*/ @Test public void testGenerateOutputFilenames() { List expected; @@ -340,17 +323,17 @@ public void testGenerateOutputFilenames() { root, "file", ".SSSSS.of.NNNNN", ".test", CompressionType.UNCOMPRESSED); FilenamePolicy policy = sink.getDynamicDestinations().getFilenamePolicy(null); - expected = Arrays.asList( - root.resolve("file.00000.of.00003.test", StandardResolveOptions.RESOLVE_FILE), - root.resolve("file.00001.of.00003.test", StandardResolveOptions.RESOLVE_FILE), - root.resolve("file.00002.of.00003.test", StandardResolveOptions.RESOLVE_FILE) - ); + expected = + Arrays.asList( + root.resolve("file.00000.of.00003.test", StandardResolveOptions.RESOLVE_FILE), + root.resolve("file.00001.of.00003.test", StandardResolveOptions.RESOLVE_FILE), + root.resolve("file.00002.of.00003.test", StandardResolveOptions.RESOLVE_FILE)); actual = generateDestinationFilenames(root, policy, 3); assertEquals(expected, actual); - expected = Collections.singletonList( - root.resolve("file.00000.of.00001.test", StandardResolveOptions.RESOLVE_FILE) - ); + expected = + Collections.singletonList( + root.resolve("file.00000.of.00001.test", StandardResolveOptions.RESOLVE_FILE)); actual = generateDestinationFilenames(root, policy, 1); assertEquals(expected, actual); @@ -396,17 +379,17 @@ public void testGenerateOutputFilenamesWithoutExtension() { root, "file", "-SSSSS-of-NNNNN", "", CompressionType.UNCOMPRESSED); FilenamePolicy policy = sink.getDynamicDestinations().getFilenamePolicy(null); - expected = Arrays.asList( - root.resolve("file-00000-of-00003", StandardResolveOptions.RESOLVE_FILE), - root.resolve("file-00001-of-00003", StandardResolveOptions.RESOLVE_FILE), - root.resolve("file-00002-of-00003", StandardResolveOptions.RESOLVE_FILE) - ); + expected = + Arrays.asList( + root.resolve("file-00000-of-00003", StandardResolveOptions.RESOLVE_FILE), + root.resolve("file-00001-of-00003", StandardResolveOptions.RESOLVE_FILE), + root.resolve("file-00002-of-00003", StandardResolveOptions.RESOLVE_FILE)); actual = generateDestinationFilenames(root, policy, 3); assertEquals(expected, actual); - expected = Collections.singletonList( - root.resolve("file-00000-of-00001", StandardResolveOptions.RESOLVE_FILE) - ); + expected = + Collections.singletonList( + root.resolve("file-00000-of-00001", StandardResolveOptions.RESOLVE_FILE)); actual = generateDestinationFilenames(root, policy, 1); assertEquals(expected, actual); @@ -479,9 +462,8 @@ private void assertReadValues(final BufferedReader br, String... values) throws } } - private File writeValuesWithWritableByteChannelFactory(final WritableByteChannelFactory factory, - String... values) - throws IOException { + private File writeValuesWithWritableByteChannelFactory( + final WritableByteChannelFactory factory, String... 
values) throws IOException { final File file = tmpFolder.newFile("test.gz"); final WritableByteChannel channel = factory.create(Channels.newChannel(new FileOutputStream(file))); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java index 55f2a87205601..1ca7169baa62a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java @@ -70,8 +70,10 @@ import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; @@ -89,17 +91,12 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Tests for the WriteFiles PTransform. - */ +/** Tests for the WriteFiles PTransform. */ @RunWith(JUnit4.class) public class WriteFilesTest { - @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); - @Rule - public final TestPipeline p = TestPipeline.create(); - @Rule - public ExpectedException thrown = ExpectedException.none(); + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public final TestPipeline p = TestPipeline.create(); + @Rule public ExpectedException thrown = ExpectedException.none(); @SuppressWarnings("unchecked") // covariant cast private static final PTransform, PCollection> IDENTITY_MAP = @@ -114,12 +111,12 @@ public String apply(String input) { private static final PTransform, PCollectionView> SHARDING_TRANSFORM = - new PTransform, PCollectionView>() { - @Override - public PCollectionView expand(PCollection input) { - return null; - } - }; + new PTransform, PCollectionView>() { + @Override + public PCollectionView expand(PCollection input) { + return null; + } + }; private static class WindowAndReshuffle extends PTransform, PCollection> { private final Window window; @@ -161,18 +158,20 @@ private String appendToTempFolder(String filename) { } private String getBaseOutputFilename() { - return getBaseOutputDirectory() - .resolve("file", StandardResolveOptions.RESOLVE_FILE).toString(); + return getBaseOutputDirectory().resolve("file", StandardResolveOptions.RESOLVE_FILE).toString(); } - /** - * Test a WriteFiles transform with a PCollection of elements. - */ + /** Test a WriteFiles transform with a PCollection of elements. */ @Test @Category(NeedsRunner.class) public void testWrite() throws IOException { - List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", - "Intimidating pigeon", "Pedantic gull", "Frisky finch"); + List inputs = + Arrays.asList( + "Critical canary", + "Apprehensive eagle", + "Intimidating pigeon", + "Pedantic gull", + "Frisky finch"); runWrite( inputs, IDENTITY_MAP, @@ -180,9 +179,7 @@ public void testWrite() throws IOException { WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } - /** - * Test that WriteFiles with an empty input still produces one shard. - */ + /** Test that WriteFiles with an empty input still produces one shard. 
*/ @Test @Category(NeedsRunner.class) public void testEmptyWrite() throws IOException { @@ -191,8 +188,7 @@ public void testEmptyWrite() throws IOException { IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); - checkFileContents(getBaseOutputFilename(), Collections.emptyList(), - Optional.of(1)); + checkFileContents(getBaseOutputFilename(), Collections.emptyList(), Optional.of(1)); } /** @@ -212,7 +208,6 @@ public void testShardedWrite() throws IOException { private ResourceId getBaseOutputDirectory() { return LocalResources.fromFile(tmpFolder.getRoot(), true) .resolve("output", StandardResolveOptions.RESOLVE_DIRECTORY); - } private SimpleSink makeSimpleSink() { @@ -267,9 +262,7 @@ public void testExpandShardedWrite() throws IOException { .withNumShards(20)); } - /** - * Test a WriteFiles transform with an empty PCollection. - */ + /** Test a WriteFiles transform with an empty PCollection. */ @Test @Category(NeedsRunner.class) public void testWriteWithEmptyPCollection() throws IOException { @@ -281,14 +274,17 @@ public void testWriteWithEmptyPCollection() throws IOException { WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } - /** - * Test a WriteFiles with a windowed PCollection. - */ + /** Test a WriteFiles with a windowed PCollection. */ @Test @Category(NeedsRunner.class) public void testWriteWindowed() throws IOException { - List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", - "Intimidating pigeon", "Pedantic gull", "Frisky finch"); + List inputs = + Arrays.asList( + "Critical canary", + "Apprehensive eagle", + "Intimidating pigeon", + "Pedantic gull", + "Frisky finch"); runWrite( inputs, new WindowAndReshuffle<>(Window.into(FixedWindows.of(Duration.millis(2)))), @@ -296,14 +292,17 @@ public void testWriteWindowed() throws IOException { WriteFiles.to(makeSimpleSink(), SerializableFunctions.identity())); } - /** - * Test a WriteFiles with sessions. - */ + /** Test a WriteFiles with sessions. */ @Test @Category(NeedsRunner.class) public void testWriteWithSessions() throws IOException { - List inputs = Arrays.asList("Critical canary", "Apprehensive eagle", - "Intimidating pigeon", "Pedantic gull", "Frisky finch"); + List inputs = + Arrays.asList( + "Critical canary", + "Apprehensive eagle", + "Intimidating pigeon", + "Pedantic gull", + "Frisky finch"); runWrite( inputs, @@ -589,19 +588,24 @@ public PerWindowFiles(ResourceId baseFilename, String suffix) { public String filenamePrefixForWindow(IntervalWindow window) { String prefix = baseFilename.isDirectory() ? 
"" : firstNonNull(baseFilename.getFilename(), ""); - return String.format("%s%s-%s", - prefix, FORMATTER.print(window.start()), FORMATTER.print(window.end())); + return String.format( + "%s%s-%s", prefix, FORMATTER.print(window.start()), FORMATTER.print(window.end())); } @Override - public ResourceId windowedFilename(WindowedContext context, OutputFileHints outputFileHints) { - IntervalWindow window = (IntervalWindow) context.getWindow(); + public ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { + IntervalWindow intervalWindow = (IntervalWindow) window; String filename = String.format( "%s-%s-of-%s%s%s", - filenamePrefixForWindow(window), - context.getShardNumber(), - context.getNumShards(), + filenamePrefixForWindow(intervalWindow), + shardNumber, + numShards, outputFileHints.getSuggestedFilenameSuffix(), suffix); return baseFilename @@ -610,17 +614,14 @@ public ResourceId windowedFilename(WindowedContext context, OutputFileHints outp } @Override - public ResourceId unwindowedFilename(Context context, OutputFileHints outputFileHints) { + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { String prefix = baseFilename.isDirectory() ? "" : firstNonNull(baseFilename.getFilename(), ""); String filename = String.format( "%s-%s-of-%s%s%s", - prefix, - context.getShardNumber(), - context.getNumShards(), - outputFileHints.getSuggestedFilenameSuffix(), - suffix); + prefix, shardNumber, numShards, outputFileHints.getSuggestedFilenameSuffix(), suffix); return baseFilename .getCurrentDirectory() .resolve(filename, StandardResolveOptions.RESOLVE_FILE); @@ -656,12 +657,14 @@ private void runShardedWrite( Optional numShards = (write.getNumShards() != null) - ? Optional.of(write.getNumShards().get()) : Optional.absent(); + ? Optional.of(write.getNumShards().get()) + : Optional.absent(); checkFileContents(baseName, inputs, numShards); } - static void checkFileContents(String baseName, List inputs, - Optional numExpectedShards) throws IOException { + static void checkFileContents( + String baseName, List inputs, Optional numExpectedShards) + throws IOException { List outputFiles = Lists.newArrayList(); final String pattern = baseName + "*"; List metadata = @@ -690,12 +693,11 @@ static void checkFileContents(String baseName, List inputs, assertThat(actual, containsInAnyOrder(inputs.toArray())); } - /** - * Options for test, exposed for PipelineOptionsFactory. - */ + /** Options for test, exposed for PipelineOptionsFactory. 
*/ public interface WriteOptions extends TestPipelineOptions { @Description("Test flag and value") String getTestFlag(); + void setTestFlag(String value); } From a6201ed1488d9ae95637002744bc316f72401e56 Mon Sep 17 00:00:00 2001 From: Stephen Sisk Date: Fri, 16 Jun 2017 11:04:07 -0700 Subject: [PATCH 199/200] JdbcIOIT now uses writeThenRead style --- sdks/java/io/common/pom.xml | 10 + .../apache/beam/sdk/io/common/TestRow.java | 114 ++++++++++ sdks/java/io/jdbc/pom.xml | 10 +- .../org/apache/beam/sdk/io/jdbc/JdbcIOIT.java | 203 +++++++++--------- .../apache/beam/sdk/io/jdbc/JdbcIOTest.java | 115 +++++----- .../beam/sdk/io/jdbc/JdbcTestDataSet.java | 130 ----------- .../beam/sdk/io/jdbc/JdbcTestHelper.java | 81 +++++++ 7 files changed, 377 insertions(+), 286 deletions(-) create mode 100644 sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/TestRow.java delete mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestDataSet.java create mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java diff --git a/sdks/java/io/common/pom.xml b/sdks/java/io/common/pom.xml index df0d94bea53c2..1a6f54b81d685 100644 --- a/sdks/java/io/common/pom.xml +++ b/sdks/java/io/common/pom.xml @@ -38,5 +38,15 @@ com.google.guava guava + + com.google.auto.value + auto-value + provided + + + junit + junit + test + diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/TestRow.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/TestRow.java new file mode 100644 index 0000000000000..5f0a2fb00b21c --- /dev/null +++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/TestRow.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.common; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.transforms.DoFn; + +/** + * Used to pass values around within test pipelines. + */ +@AutoValue +public abstract class TestRow implements Serializable, Comparable { + /** + * Manually create a test row. + */ + public static TestRow create(Integer id, String name) { + return new AutoValue_TestRow(id, name); + } + + public abstract Integer id(); + public abstract String name(); + + public int compareTo(TestRow other) { + return id().compareTo(other.id()); + } + + /** + * Creates a {@link org.apache.beam.sdk.io.common.TestRow} from the seed value. + */ + public static TestRow fromSeed(Integer seed) { + return create(seed, getNameForSeed(seed)); + } + + /** + * Returns the name field value produced from the given seed. 
+ */ + public static String getNameForSeed(Integer seed) { + return "Testval" + seed; + } + + /** + * Returns a range of {@link org.apache.beam.sdk.io.common.TestRow}s for seed values between + * rangeStart (inclusive) and rangeEnd (exclusive). + */ + public static Iterable getExpectedValues(int rangeStart, int rangeEnd) { + List ret = new ArrayList(rangeEnd - rangeStart + 1); + for (int i = rangeStart; i < rangeEnd; i++) { + ret.add(fromSeed(i)); + } + return ret; + } + + /** + * Uses the input Long values as seeds to produce {@link org.apache.beam.sdk.io.common.TestRow}s. + */ + public static class DeterministicallyConstructTestRowFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(fromSeed(c.element().intValue())); + } + } + + /** + * Outputs just the name stored in the {@link org.apache.beam.sdk.io.common.TestRow}. + */ + public static class SelectNameFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().name()); + } + } + + /** + * Precalculated hashes - you can calculate an entry by running HashingFn on + * the name() for the rows generated from seeds in [0, n). + */ + private static final Map EXPECTED_HASHES = ImmutableMap.of( + 1000, "7d94d63a41164be058a9680002914358" + ); + + /** + * Returns the hash value that {@link org.apache.beam.sdk.io.common.HashingFn} will return when it + * is run on {@link org.apache.beam.sdk.io.common.TestRow}s produced by + * getExpectedValues(0, rowCount). + */ + public static String getExpectedHashForRowCount(int rowCount) + throws UnsupportedOperationException { + String hash = EXPECTED_HASHES.get(rowCount); + if (hash == null) { + throw new UnsupportedOperationException("No hash for that row count"); + } + return hash; + } +} diff --git a/sdks/java/io/jdbc/pom.xml b/sdks/java/io/jdbc/pom.xml index 050fc6a5facc2..e5f4d7ed0f03f 100644 --- a/sdks/java/io/jdbc/pom.xml +++ b/sdks/java/io/jdbc/pom.xml @@ -105,11 +105,6 @@ 2.1.1 - - joda-time - joda-time - - com.google.auto.value @@ -168,5 +163,10 @@ test tests + + org.apache.beam + beam-sdks-java-io-common + test + diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java index e8ffad6c56bb3..32d6d9e80b41f 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java @@ -17,41 +17,39 @@ */ package org.apache.beam.sdk.io.jdbc; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; +import java.text.ParseException; import java.util.List; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; + +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.common.HashingFn; import org.apache.beam.sdk.io.common.IOTestPipelineOptions; +import org.apache.beam.sdk.io.common.TestRow; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.KV; +import 
org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.values.PCollection; import org.junit.AfterClass; -import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.postgresql.ds.PGSimpleDataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A test of {@link org.apache.beam.sdk.io.jdbc.JdbcIO} on an independent Postgres instance. * - *

    - * This test requires a running instance of Postgres, and the test dataset must exist in the
    - * database. `JdbcTestDataSet` will create the read table.
    - *
    - * You can run this test by doing the following:
    + * This test requires a running instance of Postgres. Pass in connection information using
    + * PipelineOptions:
      *

      *  mvn -e -Pio-it verify -pl sdks/java/io/jdbc -DintegrationTestPipelineOptions='[
      *  "--postgresServerName=1.2.3.4",
    @@ -67,112 +65,123 @@
      */
     @RunWith(JUnit4.class)
     public class JdbcIOIT {
    +  private static final Logger LOG = LoggerFactory.getLogger(JdbcIOIT.class);
    +  public static final int EXPECTED_ROW_COUNT = 1000;
       private static PGSimpleDataSource dataSource;
    -  private static String writeTableName;
    +  private static String tableName;
    +
    +  @Rule
    +  public TestPipeline pipelineWrite = TestPipeline.create();
    +  @Rule
    +  public TestPipeline pipelineRead = TestPipeline.create();
     
       @BeforeClass
    -  public static void setup() throws SQLException {
    +  public static void setup() throws SQLException, ParseException {
         PipelineOptionsFactory.register(IOTestPipelineOptions.class);
         IOTestPipelineOptions options = TestPipeline.testingPipelineOptions()
             .as(IOTestPipelineOptions.class);
     
    -    // We do dataSource set up in BeforeClass rather than Before since we don't need to create a new
    -    // dataSource for each test.
    -    dataSource = JdbcTestDataSet.getDataSource(options);
    -  }
    +    dataSource = getDataSource(options);
     
    -  @AfterClass
    -  public static void tearDown() throws SQLException {
    -    // Only do write table clean up once for the class since we don't want to clean up after both
    -    // read and write tests, only want to do it once after all the tests are done.
    -    JdbcTestDataSet.cleanUpDataTable(dataSource, writeTableName);
    +    tableName = JdbcTestHelper.getTableName("IT");
    +    JdbcTestHelper.createDataTable(dataSource, tableName);
       }
     
    -  private static class CreateKVOfNameAndId implements JdbcIO.RowMapper<KV<String, Integer>> {
    -    @Override
    -    public KV<String, Integer> mapRow(ResultSet resultSet) throws Exception {
    -      KV<String, Integer> kv =
    -          KV.of(resultSet.getString("name"), resultSet.getInt("id"));
    -      return kv;
    -    }
    -  }
    +  private static PGSimpleDataSource getDataSource(IOTestPipelineOptions options)
    +      throws SQLException {
    +    PGSimpleDataSource dataSource = new PGSimpleDataSource();
     
    -  private static class PutKeyInColumnOnePutValueInColumnTwo
    -      implements JdbcIO.PreparedStatementSetter<KV<Integer, String>> {
    -    @Override
    -    public void setParameters(KV<Integer, String> element, PreparedStatement statement)
    -                    throws SQLException {
    -      statement.setInt(1, element.getKey());
    -      statement.setString(2, element.getValue());
    -    }
    +    dataSource.setDatabaseName(options.getPostgresDatabaseName());
    +    dataSource.setServerName(options.getPostgresServerName());
    +    dataSource.setPortNumber(options.getPostgresPort());
    +    dataSource.setUser(options.getPostgresUsername());
    +    dataSource.setPassword(options.getPostgresPassword());
    +    dataSource.setSsl(options.getPostgresSsl());
    +
    +    return dataSource;
       }
     
    -  @Rule
    -  public TestPipeline pipeline = TestPipeline.create();
    +  @AfterClass
    +  public static void tearDown() throws SQLException {
    +    JdbcTestHelper.cleanUpDataTable(dataSource, tableName);
    +  }
     
       /**
    -   * Does a test read of a few rows from a postgres database.
    -   *
    -   * 

    Note that IT read tests must not do any data table manipulation (setup/clean up.) - * @throws SQLException + * Tests writing then reading data for a postgres database. */ @Test - public void testRead() throws SQLException { - String writeTableName = JdbcTestDataSet.READ_TABLE_NAME; + public void testWriteThenRead() { + runWrite(); + runRead(); + } - PCollection> output = pipeline.apply(JdbcIO.>read() + /** + * Writes the test dataset to postgres. + * + *

    This method does not attempt to validate the data - we do so in the read test. This does + * make it harder to tell whether a test failed in the write or read phase, but the tests are much + * easier to maintain (don't need any separate code to write test data for read tests to + * the database.) + */ + private void runWrite() { + pipelineWrite.apply(GenerateSequence.from(0).to((long) EXPECTED_ROW_COUNT)) + .apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())) + .apply(JdbcIO.write() .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) - .withQuery("select name,id from " + writeTableName) - .withRowMapper(new CreateKVOfNameAndId()) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); - - // TODO: validate actual contents of rows, not just count. - PAssert.thatSingleton( - output.apply("Count All", Count.>globally())) - .isEqualTo(1000L); + .withStatement(String.format("insert into %s values(?, ?)", tableName)) + .withPreparedStatementSetter(new JdbcTestHelper.PrepareStatementFromTestRow())); - List> expectedCounts = new ArrayList<>(); - for (String scientist : JdbcTestDataSet.SCIENTISTS) { - expectedCounts.add(KV.of(scientist, 100L)); - } - PAssert.that(output.apply("Count Scientist", Count.perKey())) - .containsInAnyOrder(expectedCounts); - - pipeline.run().waitUntilFinish(); + pipelineWrite.run().waitUntilFinish(); } /** - * Tests writes to a postgres database. + * Read the test dataset from postgres and validate its contents. + * + *

    When doing the validation, we wish to ensure that we: + * 1. Ensure *all* the rows are correct + * 2. Provide enough information in assertions such that it is easy to spot obvious errors (e.g. + * all elements have a similar mistake, or "only 5 elements were generated" and the user wants + * to see what the problem was. * - *

    Write Tests must clean up their data - in this case, it uses a new table every test run so - * that it won't interfere with read tests/other write tests. It uses finally to attempt to - * clean up data at the end of the test run. - * @throws SQLException + *

    We do not wish to generate and compare all of the expected values, so this method uses + * hashing to ensure that all expected data is present. However, hashing does not provide easy + * debugging information (failures like "every element was empty string" are hard to see), + * so we also: + * 1. Generate expected values for the first and last 500 rows + * 2. Use containsInAnyOrder to verify that their values are correct. + * Where first/last 500 rows is determined by the fact that we know all rows have a unique id - we + * can use the natural ordering of that key. */ - @Test - public void testWrite() throws SQLException { - writeTableName = JdbcTestDataSet.createWriteDataTable(dataSource); - - ArrayList> data = new ArrayList<>(); - for (int i = 0; i < 1000; i++) { - KV kv = KV.of(i, "Test"); - data.add(kv); - } - pipeline.apply(Create.of(data)) - .apply(JdbcIO.>write() - .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) - .withStatement(String.format("insert into %s values(?, ?)", writeTableName)) - .withPreparedStatementSetter(new PutKeyInColumnOnePutValueInColumnTwo())); - - pipeline.run().waitUntilFinish(); - - try (Connection connection = dataSource.getConnection(); - Statement statement = connection.createStatement(); - ResultSet resultSet = statement.executeQuery("select count(*) from " + writeTableName)) { - resultSet.next(); - int count = resultSet.getInt(1); - Assert.assertEquals(2000, count); - } - // TODO: Actually verify contents of the rows. + private void runRead() { + PCollection namesAndIds = + pipelineRead.apply(JdbcIO.read() + .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) + .withQuery(String.format("select name,id from %s;", tableName)) + .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) + .withCoder(SerializableCoder.of(TestRow.class))); + + PAssert.thatSingleton( + namesAndIds.apply("Count All", Count.globally())) + .isEqualTo((long) EXPECTED_ROW_COUNT); + + PCollection consolidatedHashcode = namesAndIds + .apply(ParDo.of(new TestRow.SelectNameFn())) + .apply("Hash row contents", Combine.globally(new HashingFn()).withoutDefaults()); + PAssert.that(consolidatedHashcode) + .containsInAnyOrder(TestRow.getExpectedHashForRowCount(EXPECTED_ROW_COUNT)); + + PCollection> frontOfList = + namesAndIds.apply(Top.smallest(500)); + Iterable expectedFrontOfList = TestRow.getExpectedValues(0, 500); + PAssert.thatSingletonIterable(frontOfList).containsInAnyOrder(expectedFrontOfList); + + PCollection> backOfList = + namesAndIds.apply(Top.largest(500)); + Iterable expectedBackOfList = + TestRow.getExpectedValues(EXPECTED_ROW_COUNT - 500, + EXPECTED_ROW_COUNT); + PAssert.thatSingletonIterable(backOfList).containsInAnyOrder(expectedBackOfList); + + pipelineRead.run().waitUntilFinish(); } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 984ce1ac78a31..4ea18ef8d45a8 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.io.jdbc; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.PrintWriter; @@ -28,18 +27,22 @@ import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Statement; import 
java.util.ArrayList; -import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import java.util.Collections; +import javax.sql.DataSource; + import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.io.common.TestRow; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.derby.drda.NetworkServerControl; @@ -58,11 +61,13 @@ */ public class JdbcIOTest implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(JdbcIOTest.class); + public static final int EXPECTED_ROW_COUNT = 1000; private static NetworkServerControl derbyServer; private static ClientDataSource dataSource; private static int port; + private static String readTableName; @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @@ -108,14 +113,16 @@ public static void startDatabase() throws Exception { dataSource.setServerName("localhost"); dataSource.setPortNumber(port); + readTableName = JdbcTestHelper.getTableName("UT_READ"); - JdbcTestDataSet.createReadDataTable(dataSource); + JdbcTestHelper.createDataTable(dataSource, readTableName); + addInitialData(dataSource, readTableName); } @AfterClass public static void shutDownDatabase() throws Exception { try { - JdbcTestDataSet.cleanUpDataTable(dataSource, JdbcTestDataSet.READ_TABLE_NAME); + JdbcTestHelper.cleanUpDataTable(dataSource, readTableName); } finally { if (derbyServer != null) { derbyServer.shutdown(); @@ -177,39 +184,43 @@ public void testDataSourceConfigurationNullUsernameAndPassword() throws Exceptio } } + /** + * Create test data that is consistent with that generated by TestRow. 
+ */ + private static void addInitialData(DataSource dataSource, String tableName) + throws SQLException { + try (Connection connection = dataSource.getConnection()) { + connection.setAutoCommit(false); + try (PreparedStatement preparedStatement = + connection.prepareStatement( + String.format("insert into %s values (?,?)", tableName))) { + for (int i = 0; i < EXPECTED_ROW_COUNT; i++) { + preparedStatement.clearParameters(); + preparedStatement.setInt(1, i); + preparedStatement.setString(2, TestRow.getNameForSeed(i)); + preparedStatement.executeUpdate(); + } + } + connection.commit(); + } + } + @Test @Category(NeedsRunner.class) public void testRead() throws Exception { - - PCollection> output = pipeline.apply( - JdbcIO.>read() + PCollection rows = pipeline.apply( + JdbcIO.read() .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) - .withQuery("select name,id from " + JdbcTestDataSet.READ_TABLE_NAME) - .withRowMapper(new JdbcIO.RowMapper>() { - @Override - public KV mapRow(ResultSet resultSet) throws Exception { - KV kv = - KV.of(resultSet.getString("name"), resultSet.getInt("id")); - return kv; - } - }) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + .withQuery("select name,id from " + readTableName) + .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) + .withCoder(SerializableCoder.of(TestRow.class))); PAssert.thatSingleton( - output.apply("Count All", Count.>globally())) - .isEqualTo(1000L); - - PAssert.that(output - .apply("Count Scientist", Count.perKey()) - ).satisfies(new SerializableFunction>, Void>() { - @Override - public Void apply(Iterable> input) { - for (KV element : input) { - assertEquals(element.getKey(), 100L, element.getValue().longValue()); - } - return null; - } - }); + rows.apply("Count All", Count.globally())) + .isEqualTo((long) EXPECTED_ROW_COUNT); + + Iterable expectedValues = TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT); + PAssert.that(rows).containsInAnyOrder(expectedValues); pipeline.run(); } @@ -217,32 +228,27 @@ public Void apply(Iterable> input) { @Test @Category(NeedsRunner.class) public void testReadWithSingleStringParameter() throws Exception { - - PCollection> output = pipeline.apply( - JdbcIO.>read() + PCollection rows = pipeline.apply( + JdbcIO.read() .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource)) .withQuery(String.format("select name,id from %s where name = ?", - JdbcTestDataSet.READ_TABLE_NAME)) + readTableName)) .withStatementPreparator(new JdbcIO.StatementPreparator() { @Override public void setParameters(PreparedStatement preparedStatement) - throws Exception { - preparedStatement.setString(1, "Darwin"); - } - }) - .withRowMapper(new JdbcIO.RowMapper>() { - @Override - public KV mapRow(ResultSet resultSet) throws Exception { - KV kv = - KV.of(resultSet.getString("name"), resultSet.getInt("id")); - return kv; + throws Exception { + preparedStatement.setString(1, TestRow.getNameForSeed(1)); } }) - .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()))); + .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) + .withCoder(SerializableCoder.of(TestRow.class))); PAssert.thatSingleton( - output.apply("Count One Scientist", Count.>globally())) - .isEqualTo(100L); + rows.apply("Count All", Count.globally())) + .isEqualTo(1L); + + Iterable expectedValues = Collections.singletonList(TestRow.fromSeed(1)); + PAssert.that(rows).containsInAnyOrder(expectedValues); pipeline.run(); } @@ -250,11 +256,13 @@ public KV mapRow(ResultSet 
resultSet) throws Exception { @Test @Category(NeedsRunner.class) public void testWrite() throws Exception { + final long rowsToAdd = 1000L; - String tableName = JdbcTestDataSet.createWriteDataTable(dataSource); + String tableName = JdbcTestHelper.getTableName("UT_WRITE"); + JdbcTestHelper.createDataTable(dataSource, tableName); try { ArrayList> data = new ArrayList<>(); - for (int i = 0; i < 1000; i++) { + for (int i = 0; i < rowsToAdd; i++) { KV kv = KV.of(i, "Test"); data.add(kv); } @@ -282,19 +290,18 @@ public void setParameters( resultSet.next(); int count = resultSet.getInt(1); - Assert.assertEquals(2000, count); + Assert.assertEquals(EXPECTED_ROW_COUNT, count); } } } } finally { - JdbcTestDataSet.cleanUpDataTable(dataSource, tableName); + JdbcTestHelper.cleanUpDataTable(dataSource, tableName); } } @Test @Category(NeedsRunner.class) public void testWriteWithEmptyPCollection() throws Exception { - pipeline .apply(Create.empty(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))) .apply(JdbcIO.>write() diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestDataSet.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestDataSet.java deleted file mode 100644 index 0b88be26af288..0000000000000 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestDataSet.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.jdbc; - -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.SQLException; -import java.sql.Statement; -import javax.sql.DataSource; -import org.apache.beam.sdk.io.common.IOTestPipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.postgresql.ds.PGSimpleDataSource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Manipulates test data used by the {@link org.apache.beam.sdk.io.jdbc.JdbcIO} tests. - * - *

    This is independent from the tests so that for read tests it can be run separately after data - * store creation rather than every time (which can be more fragile.) - */ -public class JdbcTestDataSet { - private static final Logger LOG = LoggerFactory.getLogger(JdbcTestDataSet.class); - public static final String[] SCIENTISTS = {"Einstein", "Darwin", "Copernicus", "Pasteur", "Curie", - "Faraday", "McClintock", "Herschel", "Hopper", "Lovelace"}; - /** - * Use this to create the read tables before IT read tests. - * - *

    To invoke this class, you can use this command line: - * (run from the jdbc root directory) - * mvn test-compile exec:java -Dexec.mainClass=org.apache.beam.sdk.io.jdbc.JdbcTestDataSet \ - * -Dexec.args="--postgresServerName=127.0.0.1 --postgresUsername=postgres \ - * --postgresDatabaseName=myfancydb \ - * --postgresPassword=yourpassword --postgresSsl=false" \ - * -Dexec.classpathScope=test - * @param args Please pass options from IOTestPipelineOptions used for connection to postgres as - * shown above. - */ - public static void main(String[] args) throws SQLException { - PipelineOptionsFactory.register(IOTestPipelineOptions.class); - IOTestPipelineOptions options = - PipelineOptionsFactory.fromArgs(args).as(IOTestPipelineOptions.class); - - createReadDataTable(getDataSource(options)); - } - - public static PGSimpleDataSource getDataSource(IOTestPipelineOptions options) - throws SQLException { - PGSimpleDataSource dataSource = new PGSimpleDataSource(); - - // Tests must receive parameters for connections from PipelineOptions - // Parameters should be generic to all tests that use a particular datasource, not - // the particular test. - dataSource.setDatabaseName(options.getPostgresDatabaseName()); - dataSource.setServerName(options.getPostgresServerName()); - dataSource.setPortNumber(options.getPostgresPort()); - dataSource.setUser(options.getPostgresUsername()); - dataSource.setPassword(options.getPostgresPassword()); - dataSource.setSsl(options.getPostgresSsl()); - - return dataSource; - } - - public static final String READ_TABLE_NAME = "BEAM_TEST_READ"; - - public static void createReadDataTable(DataSource dataSource) throws SQLException { - createDataTable(dataSource, READ_TABLE_NAME); - } - - public static String createWriteDataTable(DataSource dataSource) throws SQLException { - String tableName = "BEAMTEST" + org.joda.time.Instant.now().getMillis(); - createDataTable(dataSource, tableName); - return tableName; - } - - private static void createDataTable(DataSource dataSource, String tableName) throws SQLException { - try (Connection connection = dataSource.getConnection()) { - // something like this will need to happen in tests on a newly created postgres server, - // but likely it will happen in perfkit, not here - // alternatively, we may have a pipelineoption indicating whether we want to - // re-use the database or create a new one - try (Statement statement = connection.createStatement()) { - statement.execute( - String.format("create table %s (id INT, name VARCHAR(500))", tableName)); - } - - connection.setAutoCommit(false); - try (PreparedStatement preparedStatement = - connection.prepareStatement( - String.format("insert into %s values (?,?)", tableName))) { - for (int i = 0; i < 1000; i++) { - int index = i % SCIENTISTS.length; - preparedStatement.clearParameters(); - preparedStatement.setInt(1, i); - preparedStatement.setString(2, SCIENTISTS[index]); - preparedStatement.executeUpdate(); - } - } - connection.commit(); - } - - LOG.info("Created table {}", tableName); - } - - public static void cleanUpDataTable(DataSource dataSource, String tableName) - throws SQLException { - if (tableName != null) { - try (Connection connection = dataSource.getConnection(); - Statement statement = connection.createStatement()) { - statement.executeUpdate(String.format("drop table %s", tableName)); - } - } - } - -} diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java 
new file mode 100644 index 0000000000000..fedae510ea24f --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; +import javax.sql.DataSource; +import org.apache.beam.sdk.io.common.TestRow; + +/** + * Contains Test helper methods used by both Integration and Unit Tests in + * {@link org.apache.beam.sdk.io.jdbc.JdbcIO}. + */ +class JdbcTestHelper { + static String getTableName(String testIdentifier) throws ParseException { + SimpleDateFormat formatter = new SimpleDateFormat(); + formatter.applyPattern("yyyy_MM_dd_HH_mm_ss_S"); + return String.format("BEAMTEST_%s_%s", testIdentifier, formatter.format(new Date())); + } + + static void createDataTable( + DataSource dataSource, String tableName) + throws SQLException { + try (Connection connection = dataSource.getConnection()) { + try (Statement statement = connection.createStatement()) { + statement.execute( + String.format("create table %s (id INT, name VARCHAR(500))", tableName)); + } + } + } + + static void cleanUpDataTable(DataSource dataSource, String tableName) + throws SQLException { + if (tableName != null) { + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate(String.format("drop table %s", tableName)); + } + } + } + + static class CreateTestRowOfNameAndId implements JdbcIO.RowMapper { + @Override + public TestRow mapRow(ResultSet resultSet) throws Exception { + return TestRow.create( + resultSet.getInt("id"), resultSet.getString("name")); + } + } + + static class PrepareStatementFromTestRow + implements JdbcIO.PreparedStatementSetter { + @Override + public void setParameters(TestRow element, PreparedStatement statement) + throws SQLException { + statement.setLong(1, element.id()); + statement.setString(2, element.name()); + } + } + +} From eb951c2e161294510d5a23f7c641592b0a8be087 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Thu, 13 Jul 2017 12:02:31 -0700 Subject: [PATCH 200/200] [BEAM-2595] Allow table schema objects in BQ DoFn --- sdks/python/apache_beam/io/gcp/bigquery.py | 100 ++++++++++++++--- .../apache_beam/io/gcp/bigquery_test.py | 105 ++++++++++++++++-- 2 files changed, 180 insertions(+), 25 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index da8be68000521..23fd31036435e 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py 
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -1191,22 +1191,20 @@ def __init__(self, table_id, dataset_id, project_id, batch_size, schema, @staticmethod def get_table_schema(schema): - # Transform the table schema into a bigquery.TableSchema instance. - if isinstance(schema, basestring): - table_schema = bigquery.TableSchema() - schema_list = [s.strip() for s in schema.split(',')] - for field_and_type in schema_list: - field_name, field_type = field_and_type.split(':') - field_schema = bigquery.TableFieldSchema() - field_schema.name = field_name - field_schema.type = field_type - field_schema.mode = 'NULLABLE' - table_schema.fields.append(field_schema) - return table_schema - elif schema is None: - return schema - elif isinstance(schema, bigquery.TableSchema): + """Transform the table schema into a bigquery.TableSchema instance. + + Args: + schema: The schema to be used if the BigQuery table to write has to be + created. This is a dictionary object created in the WriteToBigQuery + transform. + Returns: + table_schema: The schema to be used if the BigQuery table to write has + to be created but in the bigquery.TableSchema format. + """ + if schema is None: return schema + elif isinstance(schema, dict): + return parse_table_schema_from_json(json.dumps(schema)) else: raise TypeError('Unexpected schema argument: %s.' % schema) @@ -1289,13 +1287,83 @@ def __init__(self, table, dataset=None, project=None, schema=None, self.batch_size = batch_size self.test_client = test_client + @staticmethod + def get_table_schema_from_string(schema): + """Transform the string table schema into a bigquery.TableSchema instance. + + Args: + schema: The sting schema to be used if the BigQuery table to write has + to be created. + Returns: + table_schema: The schema to be used if the BigQuery table to write has + to be created but in the bigquery.TableSchema format. + """ + table_schema = bigquery.TableSchema() + schema_list = [s.strip() for s in schema.split(',')] + for field_and_type in schema_list: + field_name, field_type = field_and_type.split(':') + field_schema = bigquery.TableFieldSchema() + field_schema.name = field_name + field_schema.type = field_type + field_schema.mode = 'NULLABLE' + table_schema.fields.append(field_schema) + return table_schema + + @staticmethod + def table_schema_to_dict(table_schema): + """Create a dictionary representation of table schema for serialization + """ + def get_table_field(field): + """Create a dictionary representation of a table field + """ + result = {} + result['name'] = field.name + result['type'] = field.type + result['mode'] = getattr(field, 'mode', 'NULLABLE') + if hasattr(field, 'description') and field.description is not None: + result['description'] = field.description + if hasattr(field, 'fields') and field.fields: + result['fields'] = [get_table_field(f) for f in field.fields] + return result + + if not isinstance(table_schema, bigquery.TableSchema): + raise ValueError("Table schema must be of the type bigquery.TableSchema") + schema = {'fields': []} + for field in table_schema.fields: + schema['fields'].append(get_table_field(field)) + return schema + + @staticmethod + def get_dict_table_schema(schema): + """Transform the table schema into a dictionary instance. + + Args: + schema: The schema to be used if the BigQuery table to write has to be + created. This can either be a dict or string or in the TableSchema + format. + Returns: + table_schema: The schema to be used if the BigQuery table to write has + to be created but in the dictionary format. 
+ """ + if isinstance(schema, dict): + return schema + elif schema is None: + return schema + elif isinstance(schema, basestring): + table_schema = WriteToBigQuery.get_table_schema_from_string(schema) + return WriteToBigQuery.table_schema_to_dict(table_schema) + elif isinstance(schema, bigquery.TableSchema): + return WriteToBigQuery.table_schema_to_dict(schema) + else: + raise TypeError('Unexpected schema argument: %s.' % schema) + def expand(self, pcoll): bigquery_write_fn = BigQueryWriteFn( table_id=self.table_reference.tableId, dataset_id=self.table_reference.datasetId, project_id=self.table_reference.projectId, batch_size=self.batch_size, - schema=self.schema, + schema=self.get_dict_table_schema(self.schema), create_disposition=self.create_disposition, write_disposition=self.write_disposition, client=self.test_client) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index b7f766bbcc859..14247bad8cbe5 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -834,12 +834,15 @@ def test_dofn_client_start_bundle_called(self): projectId='project_id', datasetId='dataset_id', tableId='table_id')) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} + fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -855,13 +858,15 @@ def test_dofn_client_start_bundle_create_called(self): projectId='project_id', datasetId='dataset_id', tableId='table_id')) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -879,13 +884,15 @@ def test_dofn_client_process_performs_batching(self): bigquery.TableDataInsertAllResponse(insertErrors=[]) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -906,13 +913,15 @@ def test_dofn_client_process_flush_called(self): bigquery.TableDataInsertAllResponse(insertErrors=[])) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -933,13 +942,15 @@ def 
test_dofn_client_finish_bundle_flush_called(self): bigquery.TableDataInsertAllResponse(insertErrors=[]) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -964,13 +975,15 @@ def test_dofn_client_no_records(self): bigquery.TableDataInsertAllResponse(insertErrors=[]) create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + schema = {'fields': [ + {'name': 'month', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} fn = beam.io.gcp.bigquery.BigQueryWriteFn( table_id='table_id', dataset_id='dataset_id', project_id='project_id', batch_size=2, - schema='month:INTEGER', + schema=schema, create_disposition=create_disposition, write_disposition=write_disposition, client=client) @@ -984,17 +997,91 @@ def test_dofn_client_no_records(self): # InsertRows not called in finish bundle as no records self.assertFalse(client.tabledata.InsertAll.called) - def test_simple_schema_parsing(self): + def test_noop_schema_parsing(self): + expected_table_schema = None table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema( - schema='s:STRING, n:INTEGER') + schema=None) + self.assertEqual(expected_table_schema, table_schema) + + def test_dict_schema_parsing(self): + schema = {'fields': [ + {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'}, + {'name': 'r', 'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ + {'name': 'x', 'type': 'INTEGER', 'mode': 'NULLABLE'}]}]} + table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema(schema) string_field = bigquery.TableFieldSchema( name='s', type='STRING', mode='NULLABLE') + nested_field = bigquery.TableFieldSchema( + name='x', type='INTEGER', mode='NULLABLE') number_field = bigquery.TableFieldSchema( name='n', type='INTEGER', mode='NULLABLE') + record_field = bigquery.TableFieldSchema( + name='r', type='RECORD', mode='NULLABLE', fields=[nested_field]) expected_table_schema = bigquery.TableSchema( - fields=[string_field, number_field]) + fields=[string_field, number_field, record_field]) self.assertEqual(expected_table_schema, table_schema) + def test_string_schema_parsing(self): + schema = 's:STRING, n:INTEGER' + expected_dict_schema = {'fields': [ + {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} + dict_schema = ( + beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema)) + self.assertEqual(expected_dict_schema, dict_schema) + + def test_table_schema_parsing(self): + string_field = bigquery.TableFieldSchema( + name='s', type='STRING', mode='NULLABLE') + nested_field = bigquery.TableFieldSchema( + name='x', type='INTEGER', mode='NULLABLE') + number_field = bigquery.TableFieldSchema( + name='n', type='INTEGER', mode='NULLABLE') + record_field = bigquery.TableFieldSchema( + name='r', type='RECORD', mode='NULLABLE', fields=[nested_field]) + schema = bigquery.TableSchema( + fields=[string_field, number_field, record_field]) + expected_dict_schema = {'fields': [ + {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'n', 'type': 'INTEGER', 'mode': 
'NULLABLE'}, + {'name': 'r', 'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ + {'name': 'x', 'type': 'INTEGER', 'mode': 'NULLABLE'}]}]} + dict_schema = ( + beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema)) + self.assertEqual(expected_dict_schema, dict_schema) + + def test_table_schema_parsing_end_to_end(self): + string_field = bigquery.TableFieldSchema( + name='s', type='STRING', mode='NULLABLE') + nested_field = bigquery.TableFieldSchema( + name='x', type='INTEGER', mode='NULLABLE') + number_field = bigquery.TableFieldSchema( + name='n', type='INTEGER', mode='NULLABLE') + record_field = bigquery.TableFieldSchema( + name='r', type='RECORD', mode='NULLABLE', fields=[nested_field]) + schema = bigquery.TableSchema( + fields=[string_field, number_field, record_field]) + table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema( + beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema)) + self.assertEqual(table_schema, schema) + + def test_none_schema_parsing(self): + schema = None + expected_dict_schema = None + dict_schema = ( + beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema)) + self.assertEqual(expected_dict_schema, dict_schema) + + def test_noop_dict_schema_parsing(self): + schema = {'fields': [ + {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'}]} + expected_dict_schema = schema + dict_schema = ( + beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema)) + self.assertEqual(expected_dict_schema, dict_schema) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)
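
A brief usage sketch (not part of the patches above): the last patch adds WriteToBigQuery.get_dict_table_schema, which normalizes a comma-separated 'name:TYPE' string or a bigquery.TableSchema into the dictionary form that is handed to BigQueryWriteFn. Assuming the GCP extra dependencies (apache_beam[gcp]) are installed, the call pattern exercised by test_string_schema_parsing looks roughly like this:

    from apache_beam.io.gcp.bigquery import WriteToBigQuery

    # Normalize a simple string schema into the dictionary representation that
    # BigQueryWriteFn.get_table_schema() later turns into a bigquery.TableSchema.
    dict_schema = WriteToBigQuery.get_dict_table_schema('s:STRING, n:INTEGER')

    # Expected result, mirroring test_string_schema_parsing above:
    # {'fields': [{'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'},
    #             {'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'}]}
    print(dict_schema)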